Big Data Analytics using Apache Spark
CN6022 - Big Data Infrastructure and Manipulation
1 Introduction
1.1 Project Overview
This project aims to demonstrate advanced Big Data manipulation and analytics using Apache Spark SQL. While the module provided a baseline dataset (Air Flight Status), we have elected to utilize the ESA Gaia Data Release 3 (DR3) for this analysis. Gaia is widely considered the largest and most complex astronomical catalog in human history, containing astrometric and photometric data for over 1.8 billion sources.
1.2 Dataset Selection and Justification
The core requirement for this coursework was to utilize a dataset that exceeds the volume and complexity of the provided sample. The Gaia DR3 dataset fits this criterion perfectly for three reasons:
- Volume: Even a modest subset of Gaia (approximately 3 million rows, well under 1% of the full catalog) significantly exceeds the size of the standard flight dataset, requiring distributed computing techniques to process efficiently.
- Complexity: Unlike the flat structure of flight logs, astronomical data requires complex feature engineering (e.g., calculating Absolute Magnitude from Parallax).
- Scientific Relevance: This dataset allows for genuine astrophysical discovery including the identification of White Dwarfs and Binary Systems.
1.3 Data Acquisition Strategy: The “Two-Tier” Architecture
Ingesting the entire 1.8 billion row catalog is computationally infeasible for this project’s scope. Furthermore, a simple random sample introduces Malmquist Bias, where bright, distant stars drown out faint, local objects. To resolve this, we architected a Two-Tier Data Strategy, acquiring two distinct datasets via the ESA Archive (ADQL):
- Dataset A: The “Galactic Survey” (Macro-Analysis): A random sample of roughly 3 million sources drawn from across the entire sky. This “Deep Field” dataset is used to map the broad structure of the Milky Way and analyze general stellar demographics: breadth over precision.
- Dataset B: The “Local Bubble” (Micro-Physics): A volume-limited sample of all stars within 100 parsecs (\(distance \le 100pc\), equivalently \(parallax \ge 10\) mas). This high-precision dataset eliminates distance-related noise, allowing us to detect faint objects like White Dwarfs that would otherwise be invisible: precision over breadth.
1.3.1 Data Schema and Key Columns
The following columns from the Gaia dataset will be used in our analysis:
- source_id: Unique identifier for each star. (64-bit Integer)
- ra: Right Ascension (celestial longitude). (Double Precision)
- dec: Declination (celestial latitude). (Double Precision)
- parallax: Parallax in milliarcseconds, used to calculate distance (\(d = 1/p\)). (Double Precision)
- parallax_error: The uncertainty in the parallax measurement. (Single Precision)
- pmra: Proper motion in the direction of Right Ascension. (Double Precision)
- pmdec: Proper motion in the direction of Declination. (Double Precision)
- phot_g_mean_mag: Mean apparent magnitude in the G-band (a measure of brightness as seen from Earth). (Single Precision)
- bp_rp: The blue-red color index, a proxy for the star’s surface temperature. (Single Precision)
- teff_gspphot: Effective temperature of the star’s photosphere, derived from photometry. (Single Precision)
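The parallax-to-distance relation above is used repeatedly in later queries (as `1000 / parallax`, since Gaia parallaxes are in milliarcseconds). A minimal sketch of the conversion; `parallax_to_distance_pc` is an illustrative helper, not part of the pipeline:

```python
def parallax_to_distance_pc(parallax_mas: float) -> float:
    """Convert a Gaia parallax (milliarcseconds) to a distance in parsecs.

    d [pc] = 1000 / parallax [mas]; only meaningful for positive,
    high signal-to-noise parallaxes.
    """
    if parallax_mas <= 0:
        raise ValueError("Parallax must be positive to yield a physical distance")
    return 1000.0 / parallax_mas

# A star with a 10 mas parallax lies at 100 pc (the Local Bubble cut-off)
print(parallax_to_distance_pc(10.0))  # 100.0
```

This is why the Local Bubble query filters on `parallax >= 10`: it is the same condition as distance ≤ 100 pc.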
1.4 Team Structure and Objectives
The analysis is divided into three distinct workstreams, each focusing on a different aspect of the data:
- Jasmi (Stellar Demographics): Focuses on classifying star populations (H-R Diagram) and identifying high-velocity outliers using the Galactic Survey.
- Yogi (Galactic Structure): Focuses on detailed mapping of the Milky Way’s density and on analysing measurement error rates across the sky.
- Jayrup (Exotic Star Hunting): Utilizes the high-precision “Local Bubble” and “Galactic Survey” data to detect rare stellar remnants and gravitationally bound binary star systems.
1.5 Understanding the Data
1.5.1 Installing dependencies
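The install cell itself was not captured in this export. A typical setup for the libraries imported below (astroquery, PySpark, pandas, pyarrow, and the plotting stack) might look like the following; the exact package list is an assumption based on the imports used later:

```shell
# Hypothetical setup cell: install the libraries imported later in this
# notebook. Using `python -m pip` avoids "command not found: pip" failures
# when pip itself is not on the shell PATH.
python -m pip install astroquery pyspark pandas pyarrow numpy matplotlib seaborn
```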
1.5.2 Downloading the Datasets
Code
from astroquery.gaia import Gaia
import os
output_dir = "../data"
os.makedirs(output_dir, exist_ok=True)
def save_strict_parquet(results, filename):
"""
Converts Astropy Table to Pandas with strict Gaia Data Model types.
"""
df = results.to_pandas()
# 1. Enforce Source_ID as 64-bit Integer (Long)
df['source_id'] = df['source_id'].astype('int64')
# 2. Enforce Double Precision (float64) for Angles/Velocity
doubles = ['ra', 'dec', 'parallax', 'pmra', 'pmdec']
for col in doubles:
if col in df.columns:
df[col] = df[col].astype('float64')
# 3. Enforce Single Precision (float32) for Errors/Magnitudes
# This saves 50% RAM on these columns vs standard floats.
floats = ['parallax_error', 'bp_rp', 'phot_g_mean_mag','teff_gspphot']
for col in floats:
if col in df.columns:
df[col] = df[col].astype('float32')
# Save
    print(f">> Saving {len(df)} rows to {filename}...")
df.to_parquet(filename, index=False)
# --- JOB 1: SURVEY ---
survey_file = os.path.join(output_dir, "gaia_survey.parquet")
if not os.path.exists(survey_file):
print(">> Downloading Survey...")
q = """
SELECT source_id, ra, dec, parallax, parallax_error, pmra, pmdec,
phot_g_mean_mag, bp_rp, teff_gspphot
FROM gaiadr3.gaia_source
WHERE parallax > 0 AND phot_g_mean_mag < 19 AND random_index < 3000000
"""
job = Gaia.launch_job_async(q)
save_strict_parquet(job.get_results(), survey_file)
else:
print(">> Survey already downloaded.")
# --- JOB 2: LOCAL BUBBLE ---
local_file = os.path.join(output_dir, "gaia_100pc.parquet")
if not os.path.exists(local_file):
print(">> Downloading Local Bubble...")
q = """
SELECT source_id, ra, dec, parallax, parallax_error, pmra, pmdec,
phot_g_mean_mag, bp_rp, teff_gspphot
FROM gaiadr3.gaia_source
WHERE parallax >= 10 AND parallax_over_error > 5
"""
job = Gaia.launch_job_async(q)
save_strict_parquet(job.get_results(), local_file)
else:
print(">> Local Bubble already downloaded.")
print(">> Done.")
>> Survey already downloaded.
>> Local Bubble already downloaded.
>> Done.
1.5.3 Exploring the Datasets
Code
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count
# Initialize Spark
spark = SparkSession.builder \
.appName("Gaia_Data_Exploration") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
# Load Datasets
survey_path = "../data/gaia_survey.parquet"
local_path = "../data/gaia_100pc.parquet"
df_survey = spark.read.parquet(survey_path)
df_local = spark.read.parquet(local_path)
print(">>> DATASET 1: GALACTIC SURVEY (Macro)")
print(f"Total Rows: {df_survey.count():,}")
df_survey.printSchema()
print("\n>>> DATASET 2: LOCAL BUBBLE (Micro)")
print(f"Total Rows: {df_local.count():,}")
df_local.printSchema()
# ====================================================
# 1. PHYSICAL COMPARISON
# ====================================================
print("\n>>> STATISTICAL COMPARISON: PARALLAX (Distance)")
print("Note: Distance (pc) is approx 1000 / parallax.")
print("-- Survey Dataset Stats --")
df_survey.select("parallax", "phot_g_mean_mag", "pmra").describe().show()
print("-- Local Bubble Stats --")
df_local.select("parallax", "phot_g_mean_mag", "pmra").describe().show()
# ====================================================
# 2. QUALITY CHECK (Null Analysis)
# ====================================================
from pyspark.sql.functions import when

def count_nulls(df, cols):
    """Return a one-row DataFrame counting NULL values in each given column."""
    return df.select([count(when(col(c).isNull(), c)).alias(c) for c in cols])

# We only check critical columns
cols_to_check = ["parallax", "pmra", "teff_gspphot", "bp_rp"]
print("\n>>> NULL VALUE ANALYSIS (Survey Dataset)")
count_nulls(df_survey, cols_to_check).show()
print("\n>>> NULL VALUE ANALYSIS (Local Dataset)")
count_nulls(df_local, cols_to_check).show()
WARNING: Using incubator modules: jdk.incubator.vector
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
>>> DATASET 1: GALACTIC SURVEY (Macro)
Total Rows: 844,868
root
|-- source_id: long (nullable = true)
|-- ra: double (nullable = true)
|-- dec: double (nullable = true)
|-- parallax: double (nullable = true)
|-- parallax_error: float (nullable = true)
|-- pmra: double (nullable = true)
|-- pmdec: double (nullable = true)
|-- phot_g_mean_mag: float (nullable = true)
|-- bp_rp: float (nullable = true)
|-- teff_gspphot: float (nullable = true)
>>> DATASET 2: LOCAL BUBBLE (Micro)
Total Rows: 541,958
root
|-- source_id: long (nullable = true)
|-- ra: double (nullable = true)
|-- dec: double (nullable = true)
|-- parallax: double (nullable = true)
|-- parallax_error: float (nullable = true)
|-- pmra: double (nullable = true)
|-- pmdec: double (nullable = true)
|-- phot_g_mean_mag: float (nullable = true)
|-- bp_rp: float (nullable = true)
|-- teff_gspphot: float (nullable = true)
>>> STATISTICAL COMPARISON: PARALLAX (Distance)
Note: Distance (pc) is approx 1000 / parallax.
-- Survey Dataset Stats --
+-------+--------------------+------------------+-------------------+
|summary| parallax| phot_g_mean_mag| pmra|
+-------+--------------------+------------------+-------------------+
| count| 844868| 844868| 844868|
| mean| 0.55071253992443|17.404869147894882| -2.373963360005405|
| stddev| 0.7697707648724794|1.4361723326971574| 7.0091972840050865|
| min|1.868718209532133E-7| 3.0238004|-377.74180694687686|
| max| 75.56887569382528| 18.999998| 649.0319386167508|
+-------+--------------------+------------------+-------------------+
-- Local Bubble Stats --
+-------+------------------+-----------------+------------------+
|summary| parallax| phot_g_mean_mag| pmra|
+-------+------------------+-----------------+------------------+
| count| 541958| 540859| 541958|
| mean|14.283010896409225|16.94108547991132|-3.085397329690918|
| stddev|7.0432588218793475|3.605457797042056| 94.7899342646246|
| min|10.000005410606198| 1.9425238|-4406.469178827325|
| max| 768.0665391873573| 21.289928| 6765.995136250774|
+-------+------------------+-----------------+------------------+
>>> NULL VALUE ANALYSIS (Survey Dataset)
+--------+----+------------+-----+
|parallax|pmra|teff_gspphot|bp_rp|
+--------+----+------------+-----+
| 0| 0| 125194|23080|
+--------+----+------------+-----+
>>> NULL VALUE ANALYSIS (Local Dataset)
+--------+----+------------+-----+
|parallax|pmra|teff_gspphot|bp_rp|
+--------+----+------------+-----+
| 0| 0| 423880|68171|
+--------+----+------------+-----+
2 Queries
2.1 Jasmi
2.1.1 The H-R Diagram
To prepare a high-quality, filtered dataset by calculating the intrinsic luminosity (Absolute Magnitude, \(M_G\)) for every star in the survey, providing the necessary data for the H-R Diagram’s Y-axis.
Methodology
This phase uses a Direct Calculation and Strict Quality Filtering approach in Spark SQL. The complexity lies not in aggregation, but in applying the critical astronomical transformation formula and enforcing a high Signal-to-Noise Ratio (SNR) on the distance data before transferring the results to Python for advanced plotting.
Parameter Justification
To ensure the resulting H-R Diagram is scientifically precise—preventing the smearing of the Main Sequence—strict parameters were justified and applied:
Absolute Magnitude (\(M_G\)) Calculation: The luminosity is calculated using the standard formula:
\[ M_G = m - 5 \log_{10}(d) + 5 \]
where \(m\) is the apparent magnitude (phot_g_mean_mag) and \(d\) is the distance in parsecs (derived from \(\frac{1000}{\text{parallax}}\)). Using \(M_G\) is mandatory because the H-R diagram plots luminosity, which is independent of Earth’s viewing distance.

Colour Index (bp_rp): This direct photometric measurement provides the most reliable measure of the star’s effective surface temperature (the X-axis of the H-R Diagram). It is directly selected for the plot without modification.

Quality Cut (Parallax / Parallax Error \(\geq 5.0\)): The signal-to-noise ratio (SNR) of the parallax is calculated dynamically and filtered to values of 5.0 or greater. This critical SNR cut is the most important filter: it ensures that only stars with highly reliable distance estimates are processed. Stars with poor parallax data would otherwise cause the Main Sequence to appear thick and indistinct, masking key stellar populations.
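To sanity-check the transformation before running it at scale, the formula can be evaluated for a single star. A minimal sketch; `absolute_magnitude` is an illustrative helper, not part of the Spark pipeline:

```python
import math

def absolute_magnitude(apparent_mag: float, parallax_mas: float) -> float:
    """M_G = m - 5*log10(d) + 5, with d = 1000 / parallax [mas]."""
    distance_pc = 1000.0 / parallax_mas
    return apparent_mag - 5.0 * math.log10(distance_pc) + 5.0

# A star with m = 10 at 100 pc (parallax = 10 mas):
# M_G = 10 - 5*log10(100) + 5 = 10 - 10 + 5 = 5
print(absolute_magnitude(10.0, 10.0))  # 5.0
```

Note the special case d = 10 pc (parallax = 100 mas), where \(M_G = m\) by definition of absolute magnitude.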
The Query Logic
The analysis was performed using a single, efficient query that focused purely on mathematical transformation and data exclusion:
Derivation: The Absolute Magnitude (\(\text{abs\_mag}\)) was calculated directly in the SELECT clause using the \(\log_{10}\) function applied to the parallax.

Sanitisation: The WHERE clause performed strict filtering to remove all low-quality and non-physical measurements:
- Exclusion of stars with invalid distance (parallax > 0).
- Exclusion of stars missing essential photometry (bp_rp IS NOT NULL and phot_g_mean_mag IS NOT NULL).
- Exclusion of all data points failing the SNR \(\geq 5.0\) standard.

Data Transfer: The final output was checked for total row count and then converted from a high-performance Spark DataFrame (raw_df) to a standard Pandas DataFrame (pdf). This transfer is necessary to facilitate the advanced plotting capabilities of Matplotlib and NumPy.
Code
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, log10, lit, count, when, sqrt, pow, percentile_approx
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import colors
spark = SparkSession.builder \
.appName("Gaia_HR_Analysis") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
# Loading Data
df_survey = spark.read.parquet("../data/gaia_survey.parquet")
# Creating a temp view
df_survey.createOrReplaceTempView("gaia_survey")
#Query
query_raw = """
SELECT
source_id,
bp_rp,
-- Calculate Absolute Magnitude (M) directly in Spark
-- Formula: Apparent Mag - 5 * log10(Distance) + 5
(phot_g_mean_mag - 5 * LOG10(1000 / parallax) + 5) AS abs_mag
FROM gaia_survey
WHERE
parallax > 0
AND (parallax / parallax_error) >= 5.0 -- For better precision
AND bp_rp IS NOT NULL
AND phot_g_mean_mag IS NOT NULL
"""
print(">>> RUNNING QUERY...")
raw_df = spark.sql(query_raw)
# Check how many stars to plot
print(f"Total stars to plot: {raw_df.count():,}")
# Convert to Pandas
pdf = raw_df.toPandas()
print("Data loaded.")
# checking the top 5 rows
print("\n------------------------------\nPrinting first five rows\n------------------------------")
pdf.head()
25/12/17 07:39:11 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
>>> RUNNING QUERY...
Total stars to plot: 296,983
Data loaded.
------------------------------
Printing first five rows
------------------------------
| | source_id | bp_rp | abs_mag |
|---|---|---|---|
| 0 | 242682832483968 | 2.051584 | 9.188913 |
| 1 | 5474605834008192 | 0.437714 | 12.840361 |
| 2 | 6591404705437056 | 1.121918 | 6.156537 |
| 3 | 6640539131235840 | 2.003047 | 6.955578 |
| 4 | 7473968945265152 | 1.275984 | 5.436462 |
Visualisation Logic
The final visualization step—performed in Python using the data prepared by this query—is a significant improvement over simple heatmapping:
- Advanced Density Scatter: Instead of using fixed SQL bins, the data is plotted using a density-mapped scatter plot.
- NumPy Density Trick: NumPy’s histogram2d function is used to assign a density score to every individual point.
- Visual Enhancement: The data is sorted by this density score and plotted with a logarithmic colour scale (cmap='magma'). This technique ensures that the densest region (the Main Sequence core) is plotted last and remains sharp, while also making faint, low-density features (like the White Dwarf sequence) clearly visible.
Code
nbins = 300
k = colors.LogNorm()
x = pdf['bp_rp'].values
y = pdf['abs_mag'].values
plt.style.use('dark_background') # Dark background makes the colors pop
# Create the grid
H, xedges, yedges = np.histogram2d(x, y, bins=nbins)
# Map every star to its bin
x_inds = np.clip(np.digitize(x, xedges) - 1, 0, nbins - 1)
y_inds = np.clip(np.digitize(y, yedges) - 1, 0, nbins - 1)
# Assign density value to each star
z = H[x_inds, y_inds]
# 3. Sort for Sharpness
# Sort the data so the densest (brightest) regions are plotted LAST (on top)
idx = z.argsort()
x, y, z = x[idx], y[idx], z[idx]
# 4. Plot: Use 'magma' colormap for clearer density progression
plt.figure(figsize=(8, 10))
# 'magma' or 'plasma' provide excellent contrast for astronomical density plots
plt.scatter(x, y, c=z, s=0.5, cmap='magma', norm=k, alpha=1.0, edgecolors='none')
# 5. Astronomy Polish
plt.title("Gaia DR3 Hertzsprung-Russell Diagram", fontsize=16)
plt.xlabel("Color Index (BP-RP)", fontsize=12)
plt.ylabel("Absolute Magnitude (M)", fontsize=12)
# Invert Y-Axis (Bright stars go at the top)
plt.ylim(17, -5)
plt.xlim(-1, 5)
# Add Colorbar
cbar = plt.colorbar()
cbar.set_label('Star Density', rotation=270, labelpad=20)
# Annotations highlighting the key stellar populations
plt.text(0.5, 14, 'White Dwarfs', color='white', fontsize=12, fontweight='bold')
plt.text(1.5, 4, 'Main Sequence', color='gold', fontsize=12, rotation=-45, fontweight='bold')
plt.text(2.5, 0, 'Giants', color='orange', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()
2.1.2 Stellar Population Census & Astrophysical Classification
To classify the stellar population into major spectral types (O through M) and to distinguish between Main Sequence stars (dwarfs) and evolved stars (giants).
Methodology
This query uses a Colour–Magnitude approach rather than relying on raw temperature estimates. This method provides more reliable classification, particularly for faint sources where temperature estimations can become inaccurate.
Dataset
Gaia DR3 Source (Random 3 Million Object Sample)
Parameter Justification
To ensure scientific accuracy and robustness in the classification process, the following parameters were selected:
Colour Index (bp_rp): The colour index is calculated as the difference between the Blue (bp) and Red (rp) photometric bands. This metric is used instead of the algorithmically derived effective temperature (teff_gspphot). Temperature estimates tend to degrade for faint stars, often producing artificially high temperature values. In contrast, the colour index is a direct photometric measurement and provides a more reliable separation of stellar spectral types.

Quality Cut (Parallax / Parallax Error > 5): The signal-to-noise ratio of the parallax is calculated dynamically and filtered to values greater than 5. This ensures that only stars with reliable distance measurements are included in the analysis. Applying this threshold reduces distance-related uncertainties and prevents the misclassification of dwarfs as giants due to parallax errors, commonly referred to as “ghost giants”.

Absolute Magnitude (\(M_G\)): The intrinsic brightness of each star is calculated using the formula:

\[ M_G = m - 5 \log_{10}(d) + 5 \]

where \(m\) is the apparent magnitude and \(d\) is the distance in parsecs. Absolute magnitude allows for a clear distinction between stellar evolutionary stages. Red giants are identified as intrinsically bright stars (\(M_G < 3.5\)), while red dwarfs are significantly dimmer (\(M_G > 3.5\)). This separation cannot be achieved using apparent magnitude alone.
Code
# 1. The Corrected Query (Calculates error ratio on the fly)
query = """
WITH PhysicalProperties AS (
SELECT
source_id,
bp_rp,
-- Calculate Absolute Magnitude (Mg)
phot_g_mean_mag - 5 * LOG10(1000 / parallax) + 5 AS abs_mag_g,
-- FIX: Calculate the Signal-to-Noise ratio manually
(parallax / parallax_error) AS calculated_poe
FROM gaia_survey
WHERE parallax > 0
AND parallax_error > 0 -- Prevent divide-by-zero errors
),
FilteredStars AS (
SELECT * FROM PhysicalProperties
WHERE calculated_poe > 5 -- Quality Cut: Only keep reliable data
),
ClassifiedStars AS (
SELECT
-- 1. Spectral Classification (Colour)
CASE
WHEN bp_rp IS NULL THEN 'Unknown'
WHEN bp_rp > 1.6 THEN 'Cool (M)'
WHEN bp_rp > 0.9 THEN 'Warm (K)'
WHEN bp_rp > 0.7 THEN 'Yellow (G)'
WHEN bp_rp > 0.4 THEN 'Yel-Wht (F)'
WHEN bp_rp > 0.0 THEN 'White (A)'
ELSE 'Hot (O/B)'
END AS spectral_type,
-- 2. Luminosity Classification (Giant vs Dwarf)
CASE
WHEN abs_mag_g < 3.5 AND bp_rp > 0.7 THEN 'Giant'
ELSE 'Main Sequence (Dwarf)'
END AS luminosity_class
FROM FilteredStars
)
SELECT
spectral_type,
COUNT(CASE WHEN luminosity_class = 'Main Sequence (Dwarf)' THEN 1 END) AS dwarf_count,
COUNT(CASE WHEN luminosity_class = 'Giant' THEN 1 END) AS giant_count,
COUNT(*) AS total_count,
ROUND(100.0 * COUNT(CASE WHEN luminosity_class = 'Giant' THEN 1 END) / COUNT(*), 1) AS giant_percentage
FROM ClassifiedStars
WHERE spectral_type != 'Unknown'
GROUP BY spectral_type
ORDER BY
CASE spectral_type
WHEN 'Hot (O/B)' THEN 1
WHEN 'White (A)' THEN 2
WHEN 'Yel-Wht (F)' THEN 3
WHEN 'Yellow (G)' THEN 4
WHEN 'Warm (K)' THEN 5
WHEN 'Cool (M)' THEN 6
ELSE 7
END
"""
# 2. Execute
df_results = spark.sql(query)
# 3. Show Results
print(">> Census Results: Dwarfs vs. Giants (calculated_poe fix)")
df_results.show(truncate=False)
>> Census Results: Dwarfs vs. Giants (calculated_poe fix)
+-------------+-----------+-----------+-----------+----------------+
|spectral_type|dwarf_count|giant_count|total_count|giant_percentage|
+-------------+-----------+-----------+-----------+----------------+
|Hot (O/B) |94 |0 |94 |0.0 |
|White (A) |745 |0 |745 |0.0 |
|Yel-Wht (F) |4536 |0 |4536 |0.0 |
|Yellow (G) |21349 |6598 |27947 |23.6 |
|Warm (K) |143669 |21220 |164889 |12.9 |
|Cool (M) |84014 |14758 |98772 |14.9 |
+-------------+-----------+-----------+-----------+----------------+
Visualisation Logic
To clearly present the results of the analysis, two visualisations were produced:
- Population Mix (Stacked Bar Chart)
A 100% stacked bar chart was used to illustrate the proportion of giants versus dwarfs within each spectral type. This visualisation highlights the evolutionary composition of the stellar population and allows for direct comparison across spectral classes.
Code
# 1. Convert your Spark results to Pandas
# (Assuming df_results is the dataframe from your last query)
pdf = df_results.toPandas()
# 2. Setup the Data for Plotting
# We need to sort the data from Hot to Cool for the X-axis
sort_order = {
'Hot (O/B)': 0, 'White (A)': 1, 'Yel-Wht (F)': 2,
'Yellow (G)': 3, 'Warm (K)': 4, 'Cool (M)': 5
}
pdf['sort_id'] = pdf['spectral_type'].map(sort_order)
pdf = pdf.sort_values('sort_id')
# Calculate Percentages for the Stacked Bar
# (We re-calculate here to ensure they sum to exactly 100 for the plot)
pdf['dwarf_pct'] = (pdf['dwarf_count'] / pdf['total_count']) * 100
pdf['giant_pct'] = (pdf['giant_count'] / pdf['total_count']) * 100
# 3. Create the Plot
fig, ax = plt.subplots(figsize=(10, 6))
# Plot Dwarfs (Bottom Bar)
p1 = ax.bar(pdf['spectral_type'], pdf['dwarf_pct'], label='Main Sequence (Dwarfs)',
color='#1f77b4', edgecolor='black', alpha=0.9)
# Plot Giants (Top Bar)
p2 = ax.bar(pdf['spectral_type'], pdf['giant_pct'], bottom=pdf['dwarf_pct'],
label='Giants (Evolved)', color='#d62728', edgecolor='black', alpha=0.9)
# 4. Styling
ax.set_title('Stellar Population Mix: Dwarfs vs. Giants (High-Quality Subset)', fontsize=16)
ax.set_ylabel('Percentage of Population (%)', fontsize=12)
ax.set_xlabel('Spectral Type', fontsize=12)
ax.set_ylim(0, 100)
ax.legend(loc='upper left', frameon=True)
ax.grid(axis='y', linestyle='--', alpha=0.4)
# 5. Add Labels
# Label the Giants if they exist
for i, (idx, row) in enumerate(pdf.iterrows()):
if row['giant_pct'] > 1:
ax.text(i, row['dwarf_pct'] + row['giant_pct']/2, f"{row['giant_pct']:.1f}%",
ha='center', va='center', color='white', fontweight='bold', fontsize=11)
# Label the Dwarfs
if row['dwarf_pct'] > 5:
ax.text(i, row['dwarf_pct']/2, f"{row['dwarf_pct']:.1f}%",
ha='center', va='center', color='white', fontweight='bold', fontsize=11)
plt.style.use('dark_background') # Dark background makes the colors pop
plt.tight_layout()
plt.show()
Interpretation of Results
O, B, A, and F Spectral Types: The analysis identified 0.0% giant stars within these spectral classes. This result is consistent with stellar evolution theory, as massive, blue stars evolve rapidly and do not remain in the blue region of the spectrum once they leave the Main Sequence.

G and K Spectral Types: A substantial giant population was observed within these classes, accounting for approximately 13–24% of the sample. This correctly traces the Red Giant Branch and distinguishes evolved stars, such as Arcturus, from nearby solar-type dwarfs.

M Spectral Type: The M-type population was overwhelmingly dominated by dwarfs. This reflects the high abundance of red dwarfs in the galaxy and the relative rarity of true M-type giants in a randomly selected stellar sample.
2.1.3 High-Velocity Outlier Detection (Kinematic Analysis)
To identify and flag the top 1% of stars in the df_local dataset exhibiting the highest Total Proper Motion (apparent speed across the sky). These stars are often key kinematic outliers, such as halo stars or nearby high-velocity dwarfs.
Columns Needed: pmra (Proper Motion in Right Ascension), pmdec (Proper Motion in Declination).

SQL Complexity: Simplified and Optimised. The executed code avoids the slow, complex nested SQL window function (PERCENT_RANK()) proposed in the original plan, replacing it with a direct, single-action calculation in Spark:

- Mathematical Transformation: Total Proper Motion (\(\mu\)) is calculated using the Pythagorean theorem: \(\mu = \sqrt{\mu_{\alpha}^2 + \mu_{\delta}^2}\). Code: sqrt(pow(col("pmra"), 2) + pow(col("pmdec"), 2))
- Threshold Calculation: The complexity is offloaded to the optimised Spark function percentile_approx(), which directly computes the 99th percentile proper motion value (pm_threshold) in a single, fast aggregation. Code: df_motion.agg(percentile_approx("total_pm", lit(0.99)))
- Flagging (Simple Filter): The final SQL query becomes a simple filter applied against pm_threshold, avoiding a subquery and expensive ranking. Code: CASE WHEN total_pm >= {pm_threshold} THEN 1 ELSE 0 END AS is_high_pm
Code
# ====================================================
# 1. Calculate Total Proper Motion & Find the Threshold
# ====================================================
# Calculate total proper motion (total_pm = sqrt(pmra^2 + pmdec^2))
df_motion = df_local.withColumn(
"total_pm",
sqrt(pow(col("pmra"), 2) + pow(col("pmdec"), 2))
)
# Find the 99th percentile (Top 1%) proper motion value
# This value will be our threshold (e.g., 100 mas/yr)
pm_threshold = df_motion.agg(
percentile_approx("total_pm", lit(0.99)).alias("threshold_value")
).collect()[0]["threshold_value"]
print(f">>> Calculated 99th Percentile Proper Motion Threshold: {pm_threshold:.2f} mas/yr")
# ====================================================
# 2. SQL Query: Prepare Data for Plotting
# ====================================================
# Create a temporary view for the motion-enhanced DataFrame
df_motion.createOrReplaceTempView("gaia_motion")
# The query selects necessary fields and flags the fast-moving stars
plot_query_with_motion = f"""
SELECT
source_id,
bp_rp,
-- Calculate Absolute Magnitude (Mg)
(phot_g_mean_mag - 5 * LOG10(1000 / parallax) + 5) AS abs_mag_g,
total_pm,
-- Flag if the star is in the top 1% of motion
CASE
WHEN total_pm >= {pm_threshold} THEN 1
ELSE 0
END AS is_high_pm
FROM gaia_motion
WHERE parallax > 0
AND (parallax / parallax_error) > 5 -- The Quality Cut
AND phot_g_mean_mag IS NOT NULL
"""
# Execute the query
df_plot_motion = spark.sql(plot_query_with_motion)
# Convert to Pandas for plotting
pdf_motion = df_plot_motion.toPandas()
print(f"Total stars prepared for plotting: {len(pdf_motion):,}")
print(f"Total high-PM stars flagged: {pdf_motion['is_high_pm'].sum():,}")
>>> Calculated 99th Percentile Proper Motion Threshold: 450.85 mas/yr
Total stars prepared for plotting: 540,859
Total high-PM stars flagged: 5,435
Visualization
A Layered Scatter Plot on the H-R Diagram was used (as seen in the executed code). This is a highly effective scientific visualisation that goes beyond the simple table proposed in the original plan. It plots:
- Background: the bulk (slow-moving) population in purple.
- Foreground: the high-velocity outliers (is_high_pm = 1) in a distinct colour (cyan), showing their position relative to the main stellar sequences.
Code
import matplotlib.pyplot as plt
# Filter the data into two subsets
pdf_slow = pdf_motion[pdf_motion['is_high_pm'] == 0]
pdf_fast = pdf_motion[pdf_motion['is_high_pm'] == 1]
# 1. Create the Plot Canvas
plt.figure(figsize=(8, 10))
plt.style.use('dark_background')
# 2. Plot the Bulk Population (Slow/Main Sequence)
# Use a high alpha (low transparency) colour to show the main density
plt.scatter(
pdf_slow['bp_rp'],
pdf_slow['abs_mag_g'],
c='purple', # Base colour
s=0.5, # Small size
alpha=0.2, # Very transparent to show density variation
edgecolors='none',
label='Bulk Population (Low PM)'
)
# 3. OVERLAY the High-Velocity Stars
# Use a distinct, bright colour and larger size
plt.scatter(
pdf_fast['bp_rp'],
pdf_fast['abs_mag_g'],
c='cyan', # Stand-out colour
s=5, # Much larger size
alpha=1.0, # Fully opaque
edgecolors='none',
label=f'High-Velocity Outliers (Top 1% > {pm_threshold:.0f} mas/yr)'
)
# 4. Polish and Axes
plt.gca().invert_yaxis() # Brighter stars (lower mag) go at the top
plt.title('HR Diagram Highlighting Kinematic Outliers (Top 1% Proper Motion)', fontsize=16)
plt.xlabel('Colour Index (BP-RP)')
plt.ylabel('Absolute Magnitude (M)')
plt.xlim(-1, 5)
plt.ylim(17, -5)
plt.legend(loc='upper right')
plt.grid(alpha=0.1)
plt.style.use('dark_background') # Dark background makes the colors pop
plt.show()
spark.stop()
2.2 Yogi
2.2.1 Galactic Plane vs Halo
Code
from pyspark.sql import SparkSession
#from pyspark.sql.functions import *
import pyspark.sql.functions as F
from pyspark.sql.window import Window
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Initialize Spark
spark = SparkSession.builder \
.appName("Member2_Galactic_Structure") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
# 1. Load the Data
parquet_path = "../data/gaia_survey.parquet"
df = spark.read.parquet(parquet_path)
# describe the gaia_survey data
df.describe().show()
25/12/17 07:39:22 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
+-------+--------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+------------------+
|summary| source_id| ra| dec| parallax| parallax_error| pmra| pmdec| phot_g_mean_mag| bp_rp| teff_gspphot|
+-------+--------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+------------------+
| count| 844868| 844868| 844868| 844868| 844868| 844868| 844868| 844868| 821788| 719674|
| mean|4.208655544660601...| 220.3738280138546|-14.661386665029692| 0.55071253992443|0.12661974095717216| -2.373963360005405|-3.0263002891849347|17.404869147894882|1.5300547396294117| 4878.39501996751|
| stddev|1.770170248935241...| 84.67372412021707| 38.84441601829229| 0.7697707648724794|0.08883984995128799| 7.0091972840050865| 6.990431755441395|1.4361723326971574|0.6075873126388648|1028.3654349007377|
| min| 42159399217024|2.620729015336566...| -89.92247730288476|1.868718209532133E-7| 0.007789557|-377.74180694687686| -565.3362997837297| 3.0238004| -0.54805183| 2739.997|
| max| 6917515734718201472| 359.9982847741013| 89.77689025849239| 75.56887569382528| 1.519277| 649.0319386167508| 340.6398385516719| 18.999998| 7.822592| 37489.88|
+-------+--------------------+--------------------+-------------------+--------------------+-------------------+-------------------+-------------------+------------------+------------------+------------------+
Code
# 2. Register as a Temp View
# This allows us to use SQL commands on the dataframe 'df'
df.createOrReplaceTempView("gaia_source")
print(">>> Executing Query 2.1: Galactic Plane vs Halo...")
# 3. The Spark SQL Query
# We use a CTE (Common Table Expression) named 'CalculatedData' to do the math first,
# and then the main query does the aggregation.
query = """
WITH CalculatedData AS (
SELECT
source_id,
-- Calculate Total Motion (Hypotenuse of pmra and pmdec)
SQRT(POW(pmra, 2) + POW(pmdec, 2)) AS total_motion,
-- Define Region using the 'DEC Proxy' method
CASE
WHEN ABS(dec) < 15 THEN 'Galactic Plane'
ELSE 'Galactic Halo'
END AS region
FROM gaia_source
)
-- Final Aggregation
SELECT
region,
ROUND(AVG(total_motion), 2) AS avg_speed,
ROUND(STDDEV(total_motion), 2) AS stddev_speed,
COUNT(*) AS star_count
FROM CalculatedData
GROUP BY region
ORDER BY avg_speed DESC
"""
# 4. Run the query
sql_results = spark.sql(query)
# 5. Show results
sql_results.show()
>>> Executing Query 2.1: Galactic Plane vs Halo...
+--------------+---------+------------+----------+
| region|avg_speed|stddev_speed|star_count|
+--------------+---------+------------+----------+
| Galactic Halo| 7.03| 7.73| 691873|
|Galactic Plane| 6.96| 8.96| 152995|
+--------------+---------+------------+----------+
Analysis
The objective of this query was to characterize stellar kinematics by comparing the proper motion of stars in the dense Galactic Disk with that of stars in the sparse Galactic Halo.
- Metric: We calculated the “Total Proper Motion” (\(\mu\)) for each star by combining its two components: \(\mu = \sqrt{\texttt{pmra}^2 + \texttt{pmdec}^2}\).
- Segmentation: Due to dataset constraints, we utilized Declination (dec) as a proxy for Galactic Latitude. We defined the “Galactic Plane” as the equatorial band (\(|dec| < 15^{\circ}\)) and the “Halo” as the high-latitude regions (\(|dec| \ge 15^{\circ}\)).
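The metric and the segmentation rule can be sketched in plain Python (a minimal illustration mirroring the SQL above, not part of the Spark pipeline):

```python
import math

def total_proper_motion(pmra: float, pmdec: float) -> float:
    """Combine the two proper-motion components (mas/yr) into one magnitude."""
    return math.sqrt(pmra**2 + pmdec**2)

def region_from_dec(dec: float) -> str:
    """The 'DEC proxy' segmentation used in the query above."""
    return "Galactic Plane" if abs(dec) < 15 else "Galactic Halo"

print(total_proper_motion(3.0, 4.0))   # 5.0 mas/yr
print(region_from_dec(-10.2))          # Galactic Plane
```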
Finding A: The “Missed Galaxy” Problem (Star Count)
- The simple reason is that the Milky Way is tilted relative to the equatorial coordinate grid.
- Our query only sees a flat strip across the celestial equator (dec between -15 and 15 degrees). Because the galactic plane is inclined to this strip, our “flat strip” misses the densest parts of the galaxy.
Finding B: The Velocity Variation
- We expected the “Halo” to show higher velocities than the “Plane”. Instead, the Disk has the larger spread (a standard deviation of 8.96 vs 7.73).
The simple reason is that distance changes how speed looks:
- The Disk contains many stars that are close to Earth. Because they are close, their proper motions appear large and varied.
- The Halo’s stars are incredibly far away. Even if they are moving fast, their distance makes them all appear to move slowly and steadily, leading to a “lower” measurement.
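The distance effect can be quantified with the standard tangential-velocity relation \(v_t = 4.74\,\mu\,d\) (with \(v_t\) in km/s, \(\mu\) in arcsec/yr, \(d\) in parsecs). The values below are illustrative, not drawn from the dataset:

```python
def proper_motion_mas(v_t_km_s: float, distance_pc: float) -> float:
    """mu [mas/yr] = 1000 * v_t / (4.74 * d): the same space velocity
    produces a far smaller proper motion at larger distance."""
    return 1000.0 * v_t_km_s / (4.74 * distance_pc)

near = proper_motion_mas(30.0, 50.0)     # hypothetical nearby disk star
far  = proper_motion_mas(30.0, 5000.0)   # hypothetical distant halo star
print(near / far)                        # 100x larger apparent motion when 100x closer
```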
Visualization
Code
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# 1. Take a random 1% sample (Crucial for performance)
pdf_subset = df.select("ra", "dec").sample(fraction=0.01, seed=42).toPandas()
# 2. Re-create the Region logic (using numpy to avoid errors)
pdf_subset['Region'] = pdf_subset['dec'].apply(lambda x: 'Galactic Plane' if np.abs(x) < 15 else 'Galactic Halo')
# 3. SET THE THEME: Dark Background for a "Space" look
plt.style.use('dark_background')
plt.figure(figsize=(12, 7))
# 4. Plot the "Halo" (Background Stars)
# We plot these first in cool blue so they look "distant"
sns.scatterplot(
data=pdf_subset[pdf_subset['Region'] == 'Galactic Halo'],
x='ra',
y='dec',
color='cornflowerblue',
s=5, # Small dots
alpha=0.3, # Faint transparency
edgecolor=None,
label='Galactic Halo (Sparse)'
)
# 5. Plot the "Plane" (The Disk)
# We plot these on top in bright Gold to represent the dense star field
sns.scatterplot(
data=pdf_subset[pdf_subset['Region'] == 'Galactic Plane'],
x='ra',
y='dec',
color='#FFD700', # Gold color
s=10, # Slightly larger dots
alpha=0.4, # Brighter
edgecolor=None,
label='Galactic Plane (Dense)'
)
# Draw the cut-off lines
plt.axhline(15, color='white', linestyle='--', linewidth=1, alpha=0.5)
plt.axhline(-15, color='white', linestyle='--', linewidth=1, alpha=0.5)
# Add text labels on the graph
plt.text(180, 0, "Milky Way Disk\n(High Density)", color='orangered',
ha='center', va='center', fontsize=12, )
plt.text(180, 60, "Galactic Halo\n(Low Density)", color='cornflowerblue',
ha='center', va='center', fontsize=10)
# 7. Final Polish
plt.title("Spatial Structure: The 'Flat' Disk vs. The 'Round' Halo", fontsize=14, color='white')
plt.xlabel("Right Ascension (Longitude)", fontsize=12)
plt.ylabel("Declination (Latitude)", fontsize=12)
plt.legend(loc='upper right', facecolor='black', edgecolor='white')
plt.grid(False) # Turn off grid to look more like space
# Astronomers view the sky looking "up", so we invert the X-axis
plt.gca().invert_xaxis()
plt.show()
2.2.2 Star Density Sky Map
Code
print(">>> Executing Query 2.2: Star Density Sky Map")
query_density = """
WITH DensityBins AS (
SELECT
-- 1. Spatial Binning (The 'Grid')
-- We divide by 2, floor it to remove decimals, then multiply by 2
-- This snaps every star to the nearest even number grid line (0, 2, 4...)
FLOOR(ra / 2) * 2 AS ra_bin,
FLOOR(dec / 2) * 2 AS dec_bin,
-- 2. Aggregation (Counting stars in that grid square)
COUNT(*) AS star_count
FROM gaia_source
GROUP BY 1, 2 -- Group by the first two columns (ra_bin, dec_bin)
),
RankedRegions AS (
SELECT
*,
-- Rank the bins from most populated (1) to least populated
RANK() OVER (ORDER BY star_count DESC) as density_rank
FROM DensityBins
)
-- 4. Final Result: Top 5 Densest Regions
SELECT * FROM RankedRegions
WHERE density_rank <= 5
"""
# 3. Run and Show
sql_density_results = spark.sql(query_density)
sql_density_results.show()
>>> Executing Query 2.2: Star Density Sky Map
25/12/17 07:39:24 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+------+-------+----------+------------+
|ra_bin|dec_bin|star_count|density_rank|
+------+-------+----------+------------+
| 270| -30| 2135| 1|
| 272| -30| 2096| 2|
| 268| -30| 2036| 3|
| 272| -28| 2011| 4|
| 274| -28| 1762| 5|
+------+-------+----------+------------+
Analysis
The goal of this query was to create the “Stellar Density Map” to identify the most populated region in the sky.
- Spatial Binning Strategy: We utilized a quantization approach to divide the continuous sky into a discrete grid. By applying FLOOR(coordinate / 2) * 2 to both Right Ascension (ra) and Declination (dec), we grouped stars into \(2^{\circ} \times 2^{\circ}\) spatial bins.
- Analytical Complexity: To rank the bins by density, we incorporated a Spark SQL window function: RANK() OVER (ORDER BY star_count DESC).
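The binning arithmetic itself is a one-liner; a plain-Python helper (illustrative only) shows how each coordinate snaps to the lower edge of its grid cell:

```python
import math

def snap_to_bin(coordinate: float, bin_size: float = 2.0) -> float:
    """Snap a continuous coordinate to the lower edge of its grid cell,
    mirroring FLOOR(coordinate / 2) * 2 from the SQL query."""
    return math.floor(coordinate / bin_size) * bin_size

print(snap_to_bin(271.3))   # 270.0
print(snap_to_bin(-29.8))   # -30.0 (floor, not truncation, for negatives)
```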
Critical Analysis:
- The results validate the binning algorithm and successfully identify the Galactic centre.
- Astronomical Validation: The coordinates returned (\(RA \approx 270^{\circ}\), \(Dec \approx -30^{\circ}\)) correspond closely to the constellation Sagittarius.
- The official coordinates of Sagittarius A* (the supermassive black hole at the center of the Milky Way) are \(RA \approx 266^{\circ}\), \(Dec \approx -29^{\circ}\).
Interpretation:
The high star counts in these bins confirm that we are looking through the galactic plane toward the central bulge.
Visualization
Code
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from pyspark.sql.functions import col
# 1. Reuse the "DensityBins" logic from Query 2.2, but save it as a real Dataframe
density_bins = spark.sql("""
SELECT
FLOOR(ra / 2) * 2 AS ra_bin,
FLOOR(dec / 2) * 2 AS dec_bin,
COUNT(*) as local_density
FROM gaia_source
GROUP BY 1, 2
""")
# We give every star a new property: "How many neighbors do I have?"
# We calculate the bins on the fly for the join
df_with_density = df.withColumn("ra_bin", F.floor(col("ra") / 2) * 2) \
.withColumn("dec_bin", F.floor(col("dec") / 2) * 2) \
.join(density_bins, ["ra_bin", "dec_bin"])
# This is enough to look like "every star" to the human eye without crashing.
pdf_visual = df_with_density.sample(fraction=0.05, seed=42).select("ra", "dec", "local_density").toPandas()
# 4. The "Glowing" Scatter Plot
plt.style.use('default')
plt.figure(figsize=(14, 8))
# We sort by density so the bright stars are plotted ON TOP of the dark ones
pdf_visual = pdf_visual.sort_values("local_density")
scatter = plt.scatter(
pdf_visual['ra'],
pdf_visual['dec'],
c=pdf_visual['local_density'], # Color by density
cmap='magma', # Magma/Inferno = glowing fire effect
s=2, # Tiny dots
alpha=0.8, # High opacity to make them pop
edgecolors='none' # No borders
)
# 5. Add a Color Bar (Legend)
cbar = plt.colorbar(scatter)
cbar.set_label('Stellar Density (Stars per bin)', rotation=270, labelpad=20, color='Black')
cbar.ax.yaxis.set_tick_params(color='Black')
plt.setp(plt.getp(cbar.ax.axes, 'yticklabels'), color='Black')
# 6. Styling
plt.title("The Milky Way: Stellar Density Visualization", fontsize=18, color='Black' )
plt.xlabel("Right Ascension", fontsize=12, color='Black')
plt.ylabel("Declination", fontsize=12, color='Black')
plt.gca().invert_xaxis() # Astronomical standard
plt.grid(False)
plt.show()
2.2.3 Parallax Error vs Brightness
Code
print(">>> Executing Query 2.3: Parallax Error vs. Brightness...")
query_quality = """
WITH QualityMetrics AS (
SELECT
-- 1. Create Brightness Bins (0-21)
-- FLOOR groups '10.2' and '10.9' into bin '10'
FLOOR(phot_g_mean_mag) AS mag_bin,
-- Pass through the raw data we need
parallax_error,
parallax
FROM gaia_source
WHERE parallax > 0 -- Filter out bad data to avoid division by zero
)
-- 2. Aggregation to find the Error Trend
SELECT
mag_bin AS magnitude,
-- A. Count how many stars are in this brightness range
COUNT(*) AS star_count,
-- B. Average Absolute Error (Raw uncertainty)
ROUND(AVG(parallax_error), 4) AS avg_raw_error,
-- C. Average Relative Error (The "Percentage" uncertainty)
-- Formula: Error / Total Signal
ROUND(AVG(parallax_error / parallax), 4) AS avg_relative_error
FROM QualityMetrics
WHERE mag_bin > 0 AND mag_bin < 22 -- Focus on the valid main range
GROUP BY mag_bin
ORDER BY mag_bin ASC -- Sort from Bright -> Dim
"""
# 3. Run and Show
quality_results = spark.sql(query_quality)
quality_results.show(25) # Show 25 rows to see the full range
>>> Executing Query 2.3: Parallax Error vs. Brightness...
+---------+----------+-------------+------------------+
|magnitude|star_count|avg_raw_error|avg_relative_error|
+---------+----------+-------------+------------------+
| 3| 1| 0.1317| 0.0438|
| 4| 1| 0.0749| 0.0146|
| 5| 7| 0.0586| 0.023|
| 6| 27| 0.0426| 0.0125|
| 7| 61| 0.0357| 0.013|
| 8| 197| 0.0401| 0.0207|
| 9| 488| 0.0263| 0.017|
| 10| 1304| 0.0264| 0.0229|
| 11| 3007| 0.0265| 0.0299|
| 12| 7054| 0.0246| 0.0487|
| 13| 15489| 0.0239| 0.0655|
| 14| 32688| 0.0297| 0.2688|
| 15| 66266| 0.0411| 1.2044|
| 16| 125846| 0.0633| 0.4907|
| 17| 223654| 0.105| 1.4188|
| 18| 368778| 0.1929| 4.8654|
+---------+----------+-------------+------------------+
Analysis
The objective of this query was to evaluate the dependability of the parallax technique at various star magnitudes. Before entering the machine learning phase, it is essential to identify the “signal-to-noise” ratios.
- Binning Strategy: We grouped stars by their Apparent Magnitude (phot_g_mean_mag) into integer bins (e.g., Magnitude 10, 11, 12…).
- For each bin we calculated:
  - Average Absolute Error: AVG(parallax_error)
  - Average Relative Error: AVG(parallax_error / parallax)
- Statistics:
  - Bright stars (Mag < 13): high photon counts yield high-precision centroids, leading to low parallax error.
  - Dim stars (Mag > 13): lower signal-to-noise ratios result in rapidly increasing errors.
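The relative-error metric from the query can be illustrated with two hypothetical stars (values chosen for illustration, not taken from the results table):

```python
def relative_parallax_error(parallax_mas: float, parallax_error_mas: float) -> float:
    """Fractional uncertainty on the parallax signal: error / signal."""
    return parallax_error_mas / parallax_mas

# A bright, nearby star: strong signal, tiny fractional error
print(relative_parallax_error(10.0, 0.03))   # 0.003 -> distance known to ~0.3%
# A dim, distant star: the error can rival the signal itself
print(relative_parallax_error(0.2, 0.2))     # 1.0 -> distance essentially unconstrained
```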
Visualization
Code
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# 1. Convert Spark Result to Pandas
pdf_quality = quality_results.toPandas()
# 2. Setup the Plot (Dual Axis)
# We want to show Star Count (Bars) AND Error Rate (Line) on the same chart
fig, ax1 = plt.subplots(figsize=(12, 6))
# 3. Plot A: Star Count (The Histogram) on Left Axis
# This shows where most of our data lives
sns.barplot(
data=pdf_quality,
x='magnitude',
y='star_count',
color='cornflowerblue',
alpha=0.3,
ax=ax1,
label='Star Count'
)
ax1.set_ylabel('Number of Stars (Log Scale)', color='blue')
ax1.tick_params(axis='y', labelcolor='blue')
ax1.set_yscale('log') # Log scale because star counts vary wildly (1 to 300,000)
# 4. Plot B: The Parallax Error (AVG Error) on Right Axis
ax2 = ax1.twinx()
sns.lineplot(
data=pdf_quality,
x=ax1.get_xticks(), # Align line with bars
y='avg_raw_error', # AVG(parallax_error)
color='red',
marker='o',
linewidth=2,
ax=ax2,
label='Avg Parallax Error (mas)'
)
ax2.set_ylabel('Avg Parallax Error (milliarcseconds)', color='red')
ax2.tick_params(axis='y', labelcolor='red')
# 5. Titles and Layout
plt.title("Data Quality Audit (Error Rate vs. Brightness)", fontsize=14)
ax1.set_xlabel("Apparent Magnitude (Lower = Brighter)", fontsize=12)
plt.grid(True, linestyle=':', alpha=0.5)
plt.tight_layout()
plt.show()
2.3 Jayrup (Exotic Star Hunting)
Finding rare and interesting stellar objects
2.3.1 White Dwarf Candidates
This query hunts for White Dwarfs, the hot, dense cores of dead stars: very hot, but very dim. Because they are stellar remnants that have not yet completely cooled down, they start out bright and blue and, as they age, fade to dim and red, sliding down the cooling sequence. Scientists use this to estimate the age of the universe, as the cooling rate is fairly stable.
Importing Libraries
Code
import numpy as np
import seaborn as sns
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.collections import LineCollection
# 1. Setup Spark (Use existing data)
spark = SparkSession.builder \
.appName("WhiteDwarf_Hunter") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
Loading Data
Color-Magnitude Method
We’re isolating white dwarfs within 100 parsecs using Gaia’s precise astrometry and photometry. White dwarfs occupy a distinct region in the HR diagram:
- Faint but not too faint to avoid noisy detections (\(M_G\) \(10–15\))
- Blue-to-yellow colors where the cooling sequence is densest (\(BP–RP < 1.0\))
- Excluding outliers to remove rare hot subdwarfs (\(BP–RP ≥ -0.5\))
We calculate the Absolute magnitude \(M_G\) with
\[ M_G = m + 5 \log_{10}(parallax) - 10 \]
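As a sanity check, the distance-modulus conversion can be written as a small helper (plain Python, parallax in milliarcseconds). A star at exactly 10 parsecs (parallax = 100 mas) should have an absolute magnitude equal to its apparent magnitude:

```python
import math

def absolute_magnitude(apparent_mag: float, parallax_mas: float) -> float:
    """M_G = m + 5*log10(parallax[mas]) - 10; valid only for parallax > 0."""
    return apparent_mag + 5 * math.log10(parallax_mas) - 10

print(absolute_magnitude(10.0, 100.0))   # 10.0 (star at 10 pc: M_G == m)
```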
Code
# Find the White Dwarfs candidates using Color-Magnitude method
df_wd = spark.sql("""
WITH candidates AS (
SELECT *,
phot_g_mean_mag + 5 * LOG10(parallax) - 10 AS absolute_magnitude
FROM gaia
WHERE parallax > 0
AND bp_rp >= -0.5
AND bp_rp < 1.0
AND phot_g_mean_mag IS NOT NULL
AND bp_rp IS NOT NULL
)
SELECT *,
CASE
WHEN absolute_magnitude > 10 AND absolute_magnitude < 15
THEN 'White Dwarf'
ELSE 'Main Sequence / Other'
END AS type
FROM candidates
""")
# Trigger computation and cache
df_wd.cache().count()
# Count WDs
wd_count = df_wd.filter(col("type") == "White Dwarf").count()
print(f">> FOUND: {wd_count} White Dwarfs using Color-Magnitude method.")
>> FOUND: 12580 White Dwarfs using Color-Magnitude method.
Code
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from pyspark.sql.functions import col
# === 1. OPTIMIZED DATA LOADING ===
# Only pull necessary columns + filter in Spark
wd_pandas = (
df_wd.filter(col("type") == "White Dwarf")
.select("bp_rp", "absolute_magnitude")
.filter(
(col("bp_rp").between(-0.5, 1.5)) &
(col("absolute_magnitude").between(8, 16)) &
col("bp_rp").isNotNull() &
col("absolute_magnitude").isNotNull()
)
.toPandas()
)
spark.stop()
# === 2. VECTORIZED DENSITY CALCULATION ===
x = wd_pandas['bp_rp'].values
y = wd_pandas['absolute_magnitude'].values
# Precompute bin edges
x_edges = np.linspace(-0.5, 1.5, 151)
y_edges = np.linspace(8, 16, 151)
# Compute histogram
H, _, _ = np.histogram2d(x, y, bins=[x_edges, y_edges])
# Digitize using precomputed edges
x_inds = np.digitize(x, x_edges[1:-1])
y_inds = np.digitize(y, y_edges[1:-1])
# Get densities directly from histogram array
z = H[x_inds, y_inds]
# === 3. SORTING OPTIMIZATION ===
# Use argsort on density (z) - but only if >10k points
if len(z) > 10000:
idx = np.argpartition(z, -10000)[-10000:] # Keep only top 10k densest points
x_sorted, y_sorted, z_sorted = x[idx], y[idx], z[idx]
else:
idx = z.argsort()
x_sorted, y_sorted, z_sorted = x[idx], y[idx], z[idx]
# === 4. PLOTTING ===
plt.figure(dpi=120)
# Use scatter with pre-sorted points (densest on top)
sc = plt.scatter(
x_sorted, y_sorted,
c=z_sorted,
s=1.5,
cmap='gist_heat',
norm=colors.LogNorm(vmin=1, vmax=np.max(H)), # Precomputed vmax
alpha=0.95,
edgecolors='none',
rasterized=True # converts to bitmap for huge speedup
)
# === 5. ASTRONOMY POLISH ===
ax = plt.gca()
ax.invert_yaxis()
ax.set_title("Gaia DR3: White Dwarf Cooling Sequence", fontsize=14)
ax.set_xlabel("Gaia BP–RP colour", fontsize=12)
ax.set_ylabel("Gaia G absolute magnitude", fontsize=12)
# Colorbar with precomputed norm
cbar = plt.colorbar(sc, pad=0.02)
cbar.set_label('Stars per bin (log scale)', rotation=270, labelpad=20)
ax.grid(True, alpha=0.2, linewidth=0.5)
plt.tight_layout(pad=1.5) # Faster than default layout
plt.show()
plt.close() # Free memory immediately
What This Plot Reveals
This is the white dwarf cooling sequence the evolutionary path of dead stars in our cosmic neighborhood. Here’s what it tells us:
- The Diagonal Band
- White dwarfs start hot and blue (top-left: \(BP-RP ≈ -0.3, M_G ≈ 10\))
- They cool and redden over billions of years, moving down and right (bottom-right: \(BP-RP ≈ 1.0, M_G ≈ 15\))
- The Color Gradient
- Bright orange/white regions: High density of white dwarfs (common evolutionary stages)
- Dark red regions: Fewer white dwarfs (rare or short-lived phases)
- The Smooth Curve
- This sequence is a stellar “fossil record”—it reveals how long white dwarfs have been cooling
- The gap at top-left (BP-RP < -0.2) shows very hot white dwarfs are rare (they cool quickly)
- The smooth curve confirms white dwarfs cool predictably—like cosmic thermometers
- The “Bifurcation”
- The population is segregated: if we look closely at the diagonal band, it isn’t a single smear; it is split into two distinct, parallel “ridges” or tracks. White Dwarfs are not a homogeneous group. This split typically represents a difference in atmospheric composition.
- Track A (Top/Blue ridge): Likely stars with Hydrogen-rich atmospheres (DA white dwarfs). Hydrogen is lighter and more opaque, acting as a blanket that keeps heat in differently than Helium.
- Track B (Bottom/Red ridge): Likely stars with Helium-rich atmospheres (DB white dwarfs).
Key insight: The densest part (bright orange) is where most white dwarfs “spend” their lives—proving they cool slowly over billions of years. This plot is why astronomers call white dwarfs “cosmic clocks” for measuring the age of our galaxy.
2.3.2 Red Giant Candidates
Red Giants are old, dying stars. They are very cool but very bright, hence they are on the top-right of the HR-diagram.
Code
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
# 1. Setup Spark
spark = SparkSession.builder \
.appName("RedGiant_DensityPlot") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
# Load only necessary columns
df = spark.read.parquet("../data/gaia_survey.parquet")
# create temp view
df.createOrReplaceTempView("gaia_source")
We use two separate subqueries for the visualization: one to give the general picture and show the relative position of the red giants on the full HR diagram, and the other to zoom in so we can analyze them in more detail.
Code
df_all = spark.sql("""
SELECT
bp_rp,
phot_g_mean_mag + 5 * LOG10(parallax) - 10 AS absolute_magnitude
FROM gaia_source
WHERE parallax > 0
AND parallax/parallax_error > 10
AND bp_rp BETWEEN -0.5 AND 3.5
-- AND (phot_g_mean_mag + 5 * LOG10(parallax) - 25) < 12
""")
df_rg = spark.sql("""
SELECT
bp_rp,
phot_g_mean_mag + 5 * LOG10(parallax) - 10 AS absolute_magnitude,
parallax/parallax_error AS parallax_snr
FROM gaia_source
WHERE parallax > 0
AND parallax/parallax_error > 10 -- Good parallax quality
AND bp_rp BETWEEN 0.7 AND 2.5 -- Focus on RGB color range
AND phot_g_mean_mag + 5 * LOG10(parallax) - 10 < 4.0 -- Bright giants
AND bp_rp IS NOT NULL
AND phot_g_mean_mag IS NOT NULL
""")
# Convert to pandas for plotting
all_pandas = df_all.toPandas()
rg_pandas = df_rg.toPandas()
spark.stop()
Code
# print no. of stars found
# print(f"Found {len(all_pandas)} stars in the survey.")
hb = plt.hexbin(
x=all_pandas['bp_rp'], y=all_pandas['absolute_magnitude'],
gridsize=500, # High res = no blocky look
extent=[-0.5, 3.5, -5, 12], # Fixed window for consistency
norm=colors.LogNorm(), # Essential: Compresses the dynamic range
cmap='inferno', # Perceptually uniform (Blue/Black -> Red -> Yellow)
mincnt=1 # Don't plot empty space
)
# Annotations
ax = plt.gca()
ax.invert_yaxis()
# A. The Main Sequence (High Density)
ax.text(1.6, 10.0, 'Main Sequence', color='purple', fontsize=10,
ha='right', rotation=-30)
# C. The True Evolutionary Path (Sub-Giant Branch)
# This arrow follows the curve, not a straight line
# Coordinates: Turn-off point -> Base of RGB
ax.annotate('', xy=(1.0, 0), xytext=(0.2,- 0.2),
arrowprops=dict(arrowstyle='->', lw=2, color='cyan', connectionstyle="arc3,rad=-0.2"))
ax.text(0,0, 'Giant\nBranch', color='purple', fontsize=9, ha='center')
# Polish
ax.set_xlabel("Gaia BP–RP colour", fontsize=12)
ax.set_ylabel("Absolute Magnitude ($M_G$)", fontsize=12)
# Colorbar
cb = plt.colorbar(hb, pad=0.02)
cb.set_label('Star Density (Log Scale)', rotation=270, labelpad=20)
plt.tight_layout()
plt.show()
# print no. of stars found
# print(f"Found {len(rg_pandas)} stars in the survey.")
hb = plt.hexbin(
x=rg_pandas['bp_rp'], y=rg_pandas['absolute_magnitude'],
gridsize=500, # High res = no blocky look
# extent=[-0.5, 3.5, -5, 12], # Fixed window for consistency
norm=colors.LogNorm(), # Essential: Compresses the dynamic range
cmap='inferno', # Perceptually uniform (Blue/Black -> Red -> Yellow)
mincnt=1 # Don't plot empty space
)
# Annotations
ax = plt.gca()
ax.invert_yaxis()
# A. The Main Sequence (High Density)
# ax.text(1.6, 3, 'Main Sequence', color='white', fontsize=10,
# ha='right', rotation=-30)
plt.text(0.8, 3.5, 'Main Sequence',
color='white', fontsize=8, fontweight='bold',
bbox=dict(facecolor='none', edgecolor='white', alpha=0.5))
# C. The True Evolutionary Path (Sub-Giant Branch)
# This arrow follows the curve, not a straight line
# Coordinates: Turn-off point -> Base of RGB
# ax.annotate('', xy=(1.0, 0), xytext=(0.2,- 0.2),
# arrowprops=dict(arrowstyle='->', lw=2, color='cyan', connectionstyle="arc3,rad=-0.2"))
# ax.text(0,0, 'Giant\nBranch', color='purple', fontsize=9, ha='center')
# Polish
ax.set_xlabel("Gaia BP–RP colour", fontsize=12)
ax.set_ylabel("Absolute Magnitude ($M_G$)", fontsize=12)
# Colorbar
cb = plt.colorbar(hb, pad=0.02)
cb.set_label('Star Density (Log Scale)', rotation=270, labelpad=20)
plt.tight_layout()
plt.show()
What This Plot Reveals
This is the Red Giant Branch (RGB), the evolutionary path of aging, low- to intermediate-mass stars as they exhaust hydrogen in their cores and begin shell burning. Here’s what it tells us:
- The Bright, Red Clump
- Red giants are cool (\(BP–RP ≈ 1.0–2.5\)) but extremely luminous (\(M_G < 0\)), placing them in the upper-right of the HR diagram.
- The densest region, the Red Clump (around \(M_G \approx -1\), \(BP–RP \approx 1.3\)), marks stars stably burning helium in their cores. This acts as a “standard candle” for distance measurements.
- The Ascending Giant Branch
- Stars move upward (brighter) and slightly redder as their outer envelopes expand after leaving the Main Sequence.
- The smooth, curved track reflects predictable stellar evolution governed by mass and composition.
- Low-Mass vs. Upper RGB
- At fainter magnitudes (M_G ≈ 2–4), we see lower-mass giants just starting their ascent.
- The brightest, reddest stars (M_G < −1) are near the tip of the RGB, where helium ignition occurs in a dramatic flash for low-mass stars.
- Why It Matters
- The RGB is a stellar aging sequence: its shape and density reveal the star formation history of our galactic neighborhood.
- Because red giants are bright, they’re visible across vast distances—making them key tracers of galactic structure.
TLDR: this plot shows stars in their retirement, glowing brightly as they near the end of their lives—before shedding their outer layers to become white dwarfs.
2.3.3 Co-Moving Pair Search (Binary Candidates)
Binary stars are gravitationally bound systems sharing common motion through space. We can detect them by finding pairs with:
- Similar positions (angular proximity)
- Similar distances (parallax values)
- Similar proper motions (space velocity vectors)
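The three criteria translate into a simple predicate. A plain-Python sketch of the SQL filter used in this section (tolerances taken from the query; the function name is illustrative):

```python
def is_comoving_pair(plx1: float, plx2: float,
                     pmra1: float, pmra2: float,
                     pmdec1: float, pmdec2: float,
                     tol: float = 1.0) -> bool:
    """A crude co-moving test: the pair must share distance (parallax, mas)
    and proper motion (mas/yr) to within the same tolerance."""
    return (abs(plx1 - plx2) < tol and
            abs(pmra1 - pmra2) < tol and
            abs(pmdec1 - pmdec2) < tol)

print(is_comoving_pair(25.1, 25.4, -12.0, -11.7, 8.2, 8.5))   # True
print(is_comoving_pair(25.1, 10.0, -12.0, -11.7, 8.2, 8.5))   # False: distances differ
```

Angular proximity is handled separately by the coarse sky-bin join, which restricts candidate pairs to the same 1° × 1° cell before this predicate is applied.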
Code
import pyspark.sql.functions as F
# initialize spark session
spark = SparkSession.builder \
.appName("Co-Moving_PairSearch") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
# Load your 100pc data
df = spark.read.parquet("../data/gaia_100pc.parquet")
# Add coarse sky bins (1° × 1°) - simple but effective
df = df.withColumn("ra_bin", F.floor(F.col("ra") / 1)) \
.withColumn("dec_bin", F.floor(F.col("dec") / 1))
df.createOrReplaceTempView("stars")Code
pairs = spark.sql("""
SELECT A.source_id as id1, B.source_id as id2,
A.ra, A.dec, A.parallax,
A.pmra, A.pmdec,
B.ra as ra2, B.dec as dec2, B.parallax as plx2,
B.pmra as pmra2, B.pmdec as pmdec2
FROM stars A JOIN stars B
ON A.ra_bin = B.ra_bin AND A.dec_bin = B.dec_bin
AND A.source_id < B.source_id
WHERE ABS(A.parallax - B.parallax) < 1 -- Same distance (within 1 mas)
AND ABS(A.pmra - B.pmra) < 1 -- Same motion
AND ABS(A.pmdec - B.pmdec) < 1
""")Code
# Distance to each star (parsecs)
pairs = pairs.withColumn("dist_pc", 1000 / F.col("parallax"))
# Angular separation (degrees) and physical separation (AU)
pairs = pairs.withColumn(
"ang_sep_deg",
F.degrees(F.acos(F.sin(F.radians("dec")) * F.sin(F.radians("dec2")) +
F.cos(F.radians("dec")) * F.cos(F.radians("dec2")) *
F.cos(F.radians("ra") - F.radians("ra2"))))
).withColumn("sep_au", F.col("ang_sep_deg") * 3600 * F.col("dist_pc"))
# Keep only likely physical pairs
binaries = pairs.filter(F.col("sep_au") < 10000)
print(f">> Found: {binaries.count()} candidate binaries")
>> Found: 5129 candidate binaries
Code
# Get photometry for plotting (join back to original data)
binaries_with_phot = binaries.alias("p").join(
df.select("source_id", "bp_rp", "phot_g_mean_mag").alias("phot"),
F.col("p.id1") == F.col("phot.source_id")
).join(
df.select("source_id", "bp_rp", "phot_g_mean_mag").alias("phot2"),
F.col("p.id2") == F.col("phot2.source_id")
).select(
"p.*",
F.col("phot.bp_rp").alias("color1"),
F.col("phot.phot_g_mean_mag").alias("mag1"),
F.col("phot2.bp_rp").alias("color2"),
F.col("phot2.phot_g_mean_mag").alias("mag2")
)
# Calculate absolute magnitudes
plot_df = binaries_with_phot.toPandas()
plot_df['abs_mag1'] = plot_df['mag1'] + 5*np.log10(plot_df['parallax']) - 10
plot_df['abs_mag2'] = plot_df['mag2'] + 5*np.log10(plot_df['plx2']) - 10
Code
# Analyze the distribution
print("\n=== Binary Separation Distribution ===")
print(f"Median separation: {plot_df['sep_au'].median():.1f} AU")
print(f"Mean separation: {plot_df['sep_au'].mean():.1f} AU")
print(f"Closest pair: {plot_df['sep_au'].min():.1f} AU")
print(f"Widest pair: {plot_df['sep_au'].max():.1f} AU")
# Classification by separation
plot_df['binary_type'] = np.select([
plot_df['sep_au'] < 100,
plot_df['sep_au'] < 1000,
plot_df['sep_au'] < 10000
], ['Close', 'Intermediate', 'Wide'], 'Very Wide')
=== Binary Separation Distribution ===
Median separation: 1685.8 AU
Mean separation: 2738.5 AU
Closest pair: 31.2 AU
Widest pair: 9994.3 AU
Code
# Plot: Sample 100 pairs for clarity
sample = plot_df.sample(n=min(100, len(plot_df)))
plt.gca().invert_yaxis()
plt.xlabel("Gaia BP-RP Color")
plt.ylabel("Absolute Magnitude ($M_G$)")
# Draw connecting lines
lines = [[(r['color1'], r['abs_mag1']), (r['color2'], r['abs_mag2'])]
for _, r in sample.iterrows()]
lc = LineCollection(lines, colors='gray', alpha=0.3, linewidths=0.5)
plt.gca().add_collection(lc)
# Plot the stars
plt.scatter(sample['color1'], sample['abs_mag1'], s=30, c='skyblue', label='Star A')
plt.scatter(sample['color2'], sample['abs_mag2'], s=30, c='orange', label='Star B')
plt.legend()
plt.tight_layout()
plt.show()
# histogram
plt.hist(plot_df['sep_au'], bins=50, log=True, color='steelblue', alpha=0.7)
plt.axvline(100, color='red', linestyle='--', label='Close (<100 AU)')
plt.axvline(1000, color='orange', linestyle='--', label='Intermediate')
# set labels
plt.xlabel("Physical Separation (AU)")
plt.ylabel("Count (log scale)")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
3 ML-Models
In this section, we will look at the machine learning models that we have trained and evaluated on the Gaia Dataset.
3.1 Jasmi
3.2 Yogi (Stellar Populations in Gaia DR3)
Using the gaia_survey dataset, the task shifts from getting data to implementing predictions. The objective is to construct, train, and evaluate classification models that can automatically categorize stellar objects based on their physical properties.
3.2.1 Purpose and Need for Implementation
- Differentiation of Stellar Objects: The main goal is to use observations to differentiate “nearby dwarfs from distant giants” and other star populations.
- Managing High-Dimensional Data: To make sure the data is appropriate for algorithmic division, the task applies transformations such as log-scaling to handle the “massive range” of astronomical quantities (like parallax).
- Automation: When working with large-scale survey data that cannot be manually classified, the task automates the complicated procedure of feature collection, scaling, and classification by using a pipeline.
3.2.2 Model Insights
What it is about: This task implements a supervised classification workflow. Training labels are first created using astronomy-based “ground truth” logic (Absolute Magnitude vs. Colour); algorithms are then trained to predict those labels using only observable features.
The Workflow:
- Feature Engineering: It constructs input features from raw data, such as calculating total_motion from the proper motion vectors (\(pmra\) and \(pmdec\)).
- Label Generation: It creates a label column by applying specific physical cuts:
- White Dwarfs (Label 2.0): Defined as having Absolute Magnitude (\(M_G\)) \(> 10\)
- Red Giants (Label 1.0): Defined as having \(M_G < 3\) and Color (\(bp\_rp\)) \(> 1.0\)
- Main Sequence (Label 0.0): Everything else
- Algorithm Implementation: It implements two distinct algorithms to solve this problem:
- Random Forest Classifier: A non-linear model suited for complex decision boundaries.
- Logistic Regression: A linear model used for comparison, which includes feature scaling and elastic net regularization.
- What it is trying to predict: The models aim to predict the label (Star Type) of a star given only its observable features (bp_rp, phot_g_mean_mag, total_motion, parallax). It validates these predictions using Cross-Validation (3-Fold) to ensure the model is robust and “not just lucky”.
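As an illustrative restatement (not part of the Spark pipeline), the three label cuts above can be expressed as a single plain-Python predicate:

```python
def star_label(abs_mag: float, bp_rp: float) -> float:
    """Apply the 'ground truth' cuts: 2.0 = White Dwarf, 1.0 = Red Giant, 0.0 = Main Sequence."""
    if abs_mag > 10:                 # intrinsically faint -> White Dwarf
        return 2.0
    if abs_mag < 3 and bp_rp > 1.0:  # bright and red -> Red Giant
        return 1.0
    return 0.0                       # everything else -> Main Sequence

print(star_label(12.0, 0.2))  # → 2.0 (White Dwarf)
print(star_label(1.5, 1.4))   # → 1.0 (Red Giant)
print(star_label(5.0, 0.8))   # → 0.0 (Main Sequence)
```

The Spark implementation in the next section applies the same logic with a chained `F.when` expression over the full dataset.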
3.2.3 Phase 1: The Random Forest Model
Data Preparation & Features
Code
import os
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import VectorAssembler, StringIndexer, StandardScaler
from pyspark.ml.classification import RandomForestClassifier, LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
spark = SparkSession.builder \
.appName("StellarPopulation_RF_Advanced") \
.config("spark.driver.memory", "4g") \
.config("spark.executor.memory", "4g") \
.config("spark.memory.offHeap.enabled", "true") \
.config("spark.memory.offHeap.size", "1g") \
.config("spark.sql.shuffle.partitions", "200") \
.getOrCreate()
# 2. Load the data
df = spark.read.parquet("../data/gaia_survey.parquet")
# A. Calculate Observables (Features)
# We add 'parallax' because it is an OBSERVABLE. It allows the model to differentiate
# nearby dwarfs from distant giants.
df = df.withColumn("total_motion", F.sqrt(F.col("pmra")**2 + F.col("pmdec")**2))
# Log-transform parallax to handle the massive range (scales data for better splits)
df = df.withColumn("log_parallax", F.log10(F.abs(F.col("parallax")) + 1e-6))
# 3. Clean data
df = df.dropna(subset=["phot_g_mean_mag", "bp_rp", "parallax", "total_motion"])
# Verify it worked
print(f"Data loaded and cleaned. Rows: {df.count()}")
df.printSchema()
Data loaded and cleaned. Rows: 821788
root
|-- source_id: long (nullable = true)
|-- ra: double (nullable = true)
|-- dec: double (nullable = true)
|-- parallax: double (nullable = true)
|-- parallax_error: float (nullable = true)
|-- pmra: double (nullable = true)
|-- pmdec: double (nullable = true)
|-- phot_g_mean_mag: float (nullable = true)
|-- bp_rp: float (nullable = true)
|-- teff_gspphot: float (nullable = true)
|-- total_motion: double (nullable = true)
|-- log_parallax: double (nullable = true)
Handling Class Imbalance
A critical step in this code is addressing the fact that White Dwarfs are rare compared to Main Sequence stars.
- The Problem: Without correction, a model could achieve high accuracy by simply ignoring the rare White Dwarfs.
- The Solution: The code calculates Class Weights using the formula: \[Weight = \frac{Total\ Rows}{3.0 \times Class\ Count}\] This assigns higher weights to the rare classes (White Dwarfs), forcing the model to pay attention to them during training.
Code
# --- A. Create Absolute Magnitude (M_G) for the Label Logic ---
# Distance d = 1000 / parallax (mas)
df = df.withColumn("distance", 1000 / F.col("parallax"))
df = df.withColumn("abs_mag", F.col("phot_g_mean_mag") - 5 * F.log10(F.col("distance")) + 5)
# --- B. Define the Classes (The "Ground Truth" Cuts) ---
df_labeled = df.withColumn("label",
F.when(F.col("abs_mag") > 10, 2.0) # White Dwarf
.when((F.col("abs_mag") < 3) & (F.col("bp_rp") > 1.0), 1.0) # Red Giant
.otherwise(0.0) # Main Sequence
)
# --- C. Handle Class Imbalance (Weighting) ---
# Calculate class counts
class_counts = df_labeled.groupBy("label").count().collect()
total_rows = df_labeled.count()
count_map = {row['label']: row['count'] for row in class_counts}
# Calculate weights: Weight = Total / (Number of Classes * Class Count)
class_weights = {k: total_rows / (3.0 * v) for k, v in count_map.items()}
print(f">> Class Weights: {class_weights}")
# Broadcast weights to a mapping column
mapping_expr = F.create_map([F.lit(x) for x in sum(class_weights.items(), ())])
df_weighted = df_labeled.withColumn("classWeight", mapping_expr[F.col("label")])
print("Class Weights Calculated:", class_weights)
>> Class Weights: {0.0: 0.3913544547816612, 1.0: 2.4997657766178145, 2.0: 22.354278874925196}
Class Weights Calculated: {0.0: 0.3913544547816612, 1.0: 2.4997657766178145, 2.0: 22.354278874925196}
Model Configuration & Tuning
Code
# Define Input Features (Observables only!)
feature_cols = ["bp_rp", "phot_g_mean_mag", "total_motion", "parallax", "log_parallax"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
# Initialize Random Forest
rf = RandomForestClassifier(
labelCol="label",
featuresCol="features",
weightCol="classWeight",
seed=42,
subsamplingRate=0.7,
featureSubsetStrategy="sqrt"
)
# Build Pipeline
pipeline = Pipeline(stages=[assembler, rf])
# Parameter Grid for Tuning
paramGrid = ParamGridBuilder() \
.addGrid(rf.numTrees, [30, 50]) \
.addGrid(rf.maxDepth, [8, 12]) \
.build()
# Evaluator (Focus on Weighted F1 to balance all classes)
evaluator = MulticlassClassificationEvaluator(
labelCol="label", predictionCol="prediction", metricName="f1")
# Cross Validator (3-Fold)
# It ensures the model is robust and not just lucky.
cv = CrossValidator(
estimator=pipeline,
estimatorParamMaps=paramGrid,
evaluator=evaluator,
numFolds=3,
parallelism=2 # Train 2 models in parallel if memory allows
)
# 5. Train/Test Split
train_data, test_data = df_weighted.randomSplit([0.8, 0.2], seed=42)
print(f"Training Count: {train_data.count()} | Test Count: {test_data.count()}")
Training Count: 657695 | Test Count: 164093
Training & Saving
Code
# Define the path
model_path = "stellar_classifier_rf_v1"
if not os.path.exists(model_path):
print(">> Model not found. Training a new model...")
print(">> Starting Cross-Validation Training (This may take 5-10 mins)...")
# 1. Train
cv_model = cv.fit(train_data)
# 2. Extract the Best Model (The winner)
best_model = cv_model.bestModel
# 3. Save ONLY the Best Model (Lighter and standard practice)
best_model.write().overwrite().save(model_path)
print(f"Model saved to {model_path}")
else:
print(">> Model found. Loading saved model...")
# 4. Load as PipelineModel
best_model = PipelineModel.load(model_path)
print(">> Model loaded.")
# --- Access the RF Stage for Parameters ---
# The Random Forest is the last stage in the pipeline (index -1)
best_rf_model = best_model.stages[-1]
print(f"\n>> Best Model Parameters:")
print(f" Num Trees: {best_rf_model.getNumTrees}")
print(f" Max Depth: {best_rf_model.getOrDefault('maxDepth')}")
# --- Evaluation ---
predictions = best_model.transform(test_data)
acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")
f1_eval = MulticlassClassificationEvaluator(metricName="f1")
prec_eval = MulticlassClassificationEvaluator(metricName="weightedPrecision")
print("\n=== FINAL MODEL EVALUATION ===")
print(f"Accuracy: {acc_eval.evaluate(predictions):.2%}")
print(f"F1 Score: {f1_eval.evaluate(predictions):.2%}")
print(f"Precision: {prec_eval.evaluate(predictions):.2%}")
>> Model found. Loading saved model...
>> Model loaded.
>> Best Model Parameters:
Num Trees: 30
Max Depth: 12
=== FINAL MODEL EVALUATION ===
Accuracy: 98.48%
F1 Score: 98.51%
Precision: 98.61%
Confusion Matrix
Code
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
# 1. Collect Predictions
print(">> Collecting predictions to driver for visualization...")
y_true = predictions.select("label").toPandas()
y_pred = predictions.select("prediction").toPandas()
# 2. Calculate Raw and Normalized Matrices
cm = confusion_matrix(y_true, y_pred)
labels = ["Main Seq", "Red Giant", "White Dwarf"]
# Normalize row-wise: Divides count by the total stars in that TRUE class
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
# 3. Plot the Heatmap
sns.heatmap(
cm_normalized,
annot=True,
fmt='.1%',
cmap='Greens',
xticklabels=labels,
yticklabels=labels,
cbar_kws={'label': 'Recall (True Positive Rate)'},
vmin=0, vmax=1 # Ensures color scale is fixed from 0% to 100%
)
plt.ylabel('Actual Star Type', fontsize=14)
plt.xlabel('Predicted Star Type', fontsize=14)
plt.tight_layout()
plt.show()
>> Collecting predictions to driver for visualization...
Feature Importance
Code
=== Feature Importance ===
log_parallax: 0.4239
parallax: 0.3025
phot_g_mean_mag: 0.2037
bp_rp: 0.0660
total_motion: 0.0038
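These values come from the `featureImportances` vector of the fitted Random Forest stage (`best_rf_model.featureImportances` in Spark). A minimal sketch of how they can be paired with the input columns and ranked, using the numbers copied from the output above:

```python
# Importances copied from the model output above, in the order of feature_cols.
feature_cols = ["bp_rp", "phot_g_mean_mag", "total_motion", "parallax", "log_parallax"]
importances = [0.0660, 0.2037, 0.0038, 0.3025, 0.4239]

# Rank features from most to least important.
ranked = sorted(zip(feature_cols, importances), key=lambda p: p[1], reverse=True)
for name, score in ranked:
    print(f"{name}: {score:.4f}")
```

The ranking confirms that the parallax-based features dominate: distance information is what separates dwarfs from giants, while proper motion contributes almost nothing.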
Visualizing Predictions
Code
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
# 1. Setup Data
plot_data = predictions.select("bp_rp", "abs_mag", "prediction").sample(False, 0.1).toPandas()
# 2. Map Predictions to Names
label_map = {0.0: "Main Sequence", 1.0: "Red Giant", 2.0: "White Dwarf"}
plot_data['Star Type'] = plot_data['prediction'].map(label_map)
# 3. Define "Academic" Palette (High Contrast for White Background)
academic_palette = {
"White Dwarf": "#9b59b6",
"Main Sequence": "#1abc9c",
"Red Giant": "#d35400"
}
# 4. Create Plot with "Seaborn Whitegrid" Style
plt.style.use('seaborn-v0_8-whitegrid')
ax = plt.gca()
order = ["Main Sequence", "Red Giant", "White Dwarf"]
for star_type in order:
subset = plot_data[plot_data['Star Type'] == star_type]
ax.scatter(
subset['bp_rp'],
subset['abs_mag'],
c=academic_palette[star_type],
s=5,
alpha=0.4,
edgecolor='none',
label=star_type
)
# 6. Professional Aesthetics
ax.invert_yaxis() # Standard Astronomy Convention
ax.set_title("Stellar Populations in Gaia DR3 (Predicted)", fontsize=16, weight='bold', pad=15)
ax.set_xlabel("Color Index ($G_{BP} - G_{RP}$) [mag]", fontsize=14)
ax.set_ylabel("Absolute Magnitude ($M_G$) [mag]", fontsize=14)
# Tick Customization
ax.tick_params(axis='both', which='major', labelsize=12)
# Legend (Boxed, Top Right, like the paper)
legend = ax.legend(
title='Classification',
fontsize=11,
title_fontsize=12,
loc='upper right',
frameon=True,
fancybox=False, # Square corners like the paper
edgecolor='black',
framealpha=1
)
plt.tight_layout()
plt.show()
3.2.4 Phase 2: Logistic Regression
Model setup
Code
# 1. Feature Scaling (Crucial for Logistic Regression)
# This ensures 'total_motion' (large values) doesn't drown out 'bp_rp' (small values).
scaler = StandardScaler(
inputCol="features",
outputCol="scaled_features",
withStd=True,
withMean=True
)
# 2. Define the Estimator
# - family="multinomial": Explicitly tells Spark to handle 3 classes.
# - weightCol="classWeight": Uses the same weights as RF to handle the class imbalance.
lr = LogisticRegression(
labelCol="label",
featuresCol="scaled_features",
weightCol="classWeight",
family="multinomial",
maxIter=100
)
Model Configuration & Tuning
Code
# 3. Build the Pipeline
# Pipeline flow: Raw Data -> Vector -> Scaled Vector -> Logistic Regression
pipeline_lr = Pipeline(stages=[assembler, scaler, lr])
# 4. Create Parameter Grid (Hyperparameter Tuning)
# - regParam: Controls regularization strength (prevents overfitting).
paramGrid_lr = ParamGridBuilder() \
.addGrid(lr.regParam, [0.01, 0.1]) \
.addGrid(lr.elasticNetParam, [0.0, 0.5]) \
.build()
# 5. Cross-Validation Setup
# We use the same 3-fold strategy as the Random Forest for consistency.
cv_lr = CrossValidator(
estimator=pipeline_lr,
estimatorParamMaps=paramGrid_lr,
evaluator=evaluator, # Reusing the F1 evaluator from previous model
numFolds=3,
parallelism=2
)
Model Training & Evaluation
Code
# 6. Train the Model
print(">> Training Advanced Logistic Regression (with Scaling & CV)...")
cv_model_lr = cv_lr.fit(train_data)
best_model_lr = cv_model_lr.bestModel
print(">> Training Complete.")
# --- Save the Best LR Model ---
model_path_lr = "stellar_classifier_lr_v1"
best_model_lr.write().overwrite().save(model_path_lr)
print(f">> Model saved to {model_path_lr}")
# --- Metrics Evaluation ---
predictions_lr = best_model_lr.transform(test_data)
acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")
f1_eval = MulticlassClassificationEvaluator(metricName="f1")
print("\n=== LOGISTIC REGRESSION RESULTS ===")
print(f"Accuracy: {acc_eval.evaluate(predictions_lr):.2%}")
print(f"Weighted F1 Score: {f1_eval.evaluate(predictions_lr):.2%}")
# --- Advanced Analysis: Extract Coefficients ---
# Extract the coefficients to explain what the linear model has learned.
best_lr_stage = best_model_lr.stages[-1] # Extract the LR stage from pipeline
print("\n>> Model Coefficients (Linear Weights):")
# Coefficients is a matrix: 3 classes x 5 features
coeff_matrix = best_lr_stage.coefficientMatrix
# Intercepts: 3 values (one per class)
intercepts = best_lr_stage.interceptVector
# Print coefficients for the White Dwarf class (Label 2.0)
# This helps explain what makes a star a "White Dwarf" according to the math.
wd_coeffs = coeff_matrix.toArray()[2]
print(f"White Dwarf Coefficients (vs Features): {wd_coeffs}")
>> Training Advanced Logistic Regression (with Scaling & CV)...
>> Training Complete.
>> Model saved to stellar_classifier_lr_v1
=== LOGISTIC REGRESSION RESULTS ===
Accuracy: 92.02%
Weighted F1 Score: 92.92%
>> Model Coefficients (Linear Weights):
White Dwarf Coefficients (vs Features): [0.03140263 1.14112309 0.00354785 0.58773451 1.17785557]
Visualizing Predictions
Code
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.sql import functions as F
# --- 1. Data Preparation ---
# print(">> Collecting data for Visualization...")
# Assuming 'predictions_lr' is your existing Spark DataFrame
df_viz = predictions_lr.select("bp_rp", "abs_mag", "label", "prediction") \
.sample(False, 0.2, seed=42) \
.toPandas()
# --- Logic for Plot A: Error Analysis ---
df_viz["is_correct"] = df_viz["label"] == df_viz["prediction"]
df_viz["Prediction Status"] = df_viz["is_correct"].map({True: "Correct", False: "Misclassified"})
# --- Logic for Plot B: Prediction Classes ---
label_map = {0.0: "Main Sequence", 1.0: "Red Giant", 2.0: "White Dwarf"}
df_viz["Predicted Star Type"] = df_viz["prediction"].map(label_map)
# Set Global Style
sns.set_style("whitegrid")
custom_red = "#e63946"
custom_med_blue = "#457b9d"
custom_dark_blue = "#1d3557"
# ==========================================
# PLOT 1: The Error Map
# ==========================================
fig1, ax1 = plt.subplots(figsize=(8, 6))
error_palette = {"Correct": custom_dark_blue, "Misclassified": custom_red}
sns.scatterplot(
data=df_viz,
x="bp_rp",
y="abs_mag",
hue="Prediction Status",
palette=error_palette,
s=15, # Increased size slightly since the plot is bigger now
alpha=0.5,
edgecolor=None,
ax=ax1
)
ax1.set_title("A. Error Analysis (Where did it fail?)", fontsize=14, weight='bold')
ax1.legend(loc="upper right", frameon=True, edgecolor='black', title="Accuracy", fontsize=10, title_fontsize=11)
# Formatting
ax1.invert_yaxis()
ax1.set_xlabel("Color Index ($G_{BP} - G_{RP}$)", fontsize=12)
ax1.set_ylabel("Absolute Magnitude ($M_G$)", fontsize=12)
ax1.grid(True, linestyle='-', alpha=0.3)
ax1.tick_params(axis='both', which='major', labelsize=10)
plt.tight_layout()
plt.show()
# ==========================================
# PLOT 2: The Prediction Map
# ==========================================
fig2, ax2 = plt.subplots(figsize=(8, 6))
model_palette = {
"Red Giant": custom_red,
"Main Sequence": custom_med_blue,
"White Dwarf": custom_dark_blue
}
sns.scatterplot(
data=df_viz,
x="bp_rp",
y="abs_mag",
hue="Predicted Star Type",
palette=model_palette,
hue_order=["Main Sequence", "Red Giant", "White Dwarf"],
s=15,
alpha=0.5,
edgecolor=None,
ax=ax2
)
ax2.set_title("B. Model Predictions (Linear Boundaries)", fontsize=14, weight='bold')
ax2.legend(loc="upper right", frameon=True, edgecolor='black', title="Predicted Class", fontsize=10, title_fontsize=11)
# Formatting
ax2.invert_yaxis()
ax2.set_xlabel("Color Index ($G_{BP} - G_{RP}$)", fontsize=12)
ax2.set_ylabel("Absolute Magnitude ($M_G$)", fontsize=12)
ax2.grid(True, linestyle='-', alpha=0.3)
ax2.tick_params(axis='both', which='major', labelsize=10)
plt.tight_layout()
plt.show()
3.2.5 Phase 3: Comparing Models
Code
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# --- 1. Define Evaluation Function ---
def get_metrics(predictions, model_name):
# Initialize evaluators
acc_eval = MulticlassClassificationEvaluator(metricName="accuracy")
f1_eval = MulticlassClassificationEvaluator(metricName="f1")
prec_eval = MulticlassClassificationEvaluator(metricName="weightedPrecision")
rec_eval = MulticlassClassificationEvaluator(metricName="weightedRecall")
# Calculate scores
return {
"Model": model_name,
"Accuracy": acc_eval.evaluate(predictions),
"F1 Score": f1_eval.evaluate(predictions),
"Precision": prec_eval.evaluate(predictions),
"Recall": rec_eval.evaluate(predictions)
}
# --- 2. Collect Data ---
print(">> Calculating metrics for comparison...")
# Assuming 'predictions' is from Random Forest and 'predictions_lr' is from Logistic Regression
metrics_rf = get_metrics(predictions, "Random Forest")
metrics_lr = get_metrics(predictions_lr, "Logistic Regression")
# Create DataFrame
df_metrics = pd.DataFrame([metrics_rf, metrics_lr])
# Melt for plotting (Long format)
df_melted = df_metrics.melt(id_vars="Model", var_name="Metric", value_name="Score")
>> Calculating metrics for comparison...
Code
# --- 3. Create Comparison Visual ---
plt.figure(figsize=(10, 6))
sns.set_style("whitegrid")
# Custom Palette: Dark Blue vs Red
custom_palette = ["#1d3557", "#e63946"]
ax = sns.barplot(
data=df_melted,
x="Metric",
y="Score",
hue="Model",
palette=custom_palette,
edgecolor="black",
linewidth=0.8
)
# --- 4. Aesthetics ---
plt.title("Performance Showdown: Random Forest vs. Logistic Regression", fontsize=14, weight='bold', pad=15)
plt.ylabel("Score (0.0 - 1.0)", fontsize=12)
plt.xlabel("Evaluation Metric", fontsize=12)
plt.ylim(0.8, 1.0) # Zoom in to show differences clearly (Adjust if scores are lower)
plt.legend(title="Algorithm", loc="lower right", frameon=True, edgecolor='black')
# Add values on top of bars
for container in ax.containers:
ax.bar_label(container, fmt='%.3f', padding=3, fontsize=10)
plt.tight_layout()
plt.show()
# Print Table for Report
print("\n=== FINAL COMPARISON TABLE ===")
print(df_metrics.round(4))
=== FINAL COMPARISON TABLE ===
Model Accuracy F1 Score Precision Recall
0 Random Forest 0.9848 0.9851 0.9861 0.9848
1 Logistic Regression 0.9202 0.9292 0.9519 0.9202
Comparison Analysis
Conclusion: Which Model Performed Better?
The Random Forest Classifier is the superior model for this prediction goal.
While Logistic Regression provides a useful baseline and helps explain linear relationships, the Random Forest is better suited to the specific scientific nature of the data.
- Performance: As shown in the bar chart, the Random Forest model consistently achieves higher scores across F1 and Accuracy.
- Reasoning: The physical boundaries between star types (specifically Main Sequence vs. Red Giants) are curved, not straight. Random Forest handles these complex, non-linear boundaries well.
- Weakness of Logistic Regression: The Logistic Regression model attempts to draw a straight line through the data. As seen in the “Error Map” previously, this linear boundary cuts through the curved Red Giant branch, leading to higher misclassification rates in that specific region.
3.3 Jayrup
The objective of this task is to predict stellar parallax using features derived from the Gaia dataset. Parallax is a continuous variable and therefore the task is formulated as a regression problem. Predicting parallax allows indirect estimation of stellar distance, which is a fundamental problem in astrophysics.
Understanding the data
The following features were selected based on their relevance and exploratory data analysis:
Total Proper Motion (\(\mu\)): Computed as \[ \mu = \sqrt{\text{pmra}^2 + \text{pmdec}^2} \]
Proper motion provides kinematic information related to stellar distance.
Photometric G-band Magnitude (phot_g_mean_mag): Represents apparent brightness, which is distance dependent.
Colour Index (bp_rp): Used as a proxy for stellar temperature and population type.
Effective Temperature (teff_gspphot): Provides additional physical context to distinguish between stellar populations.
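As a quick numeric check of the total proper motion formula (using hypothetical component values, not catalogue rows):

```python
import math

# Hypothetical proper-motion components in mas/yr.
pmra, pmdec = 3.0, 4.0

# Total proper motion is the Euclidean norm of the two components.
total_motion = math.hypot(pmra, pmdec)
print(total_motion)  # → 5.0
```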
Code
spark = SparkSession.builder \
.appName("Parallax_Prediction") \
.config("spark.driver.memory", "4g") \
.getOrCreate()
df = spark.read.parquet("../data/gaia_100pc.parquet")
# Show basic stats
df.select('pmra', 'pmdec', 'parallax').describe().show()
# Count total rows
print(f"Total rows: {df.count()}")
+-------+------------------+-------------------+------------------+
|summary| pmra| pmdec| parallax|
+-------+------------------+-------------------+------------------+
| count| 541958| 541958| 541958|
| mean|-3.085397329690918|-23.233364391692742|14.283010896409225|
| stddev| 94.7899342646246| 88.4285353154643|7.0432588218793475|
| min|-4406.469178827325|-5817.8001940492695|10.000005410606198|
| max| 6765.995136250774| 10362.394206546573| 768.0665391873573|
+-------+------------------+-------------------+------------------+
Total rows: 541958
Code
+-------+------------------+-------------------+-------------------+
|summary| pmra| pmdec| total_motion|
+-------+------------------+-------------------+-------------------+
| count| 541958| 541958| 541958|
| mean|-3.085397329690918|-23.233364391692742| 72.42738578220973|
| stddev| 94.7899342646246| 88.4285353154643| 110.037773135682|
| min|-4406.469178827325|-5817.8001940492695|0.00949718093624809|
| max| 6765.995136250774| 10362.394206546573| 10393.348722273944|
+-------+------------------+-------------------+-------------------+
Visualizing the data
Code
from pyspark.sql.functions import col, sqrt, mean, stddev
import seaborn as sns
import matplotlib.pyplot as plt
# First, let's calculate total motion
df = df.withColumn("total_motion", sqrt(col("pmra")**2 + col("pmdec")**2))
# Convert to Pandas for plotting
pdf = df.select("total_motion").toPandas()
# Plot using Seaborn or Matplotlib
# plt.figure(figsize=(10, 6))
sns.histplot(pdf["total_motion"], bins=100, kde=True)
plt.show()
import numpy as np
# Create 100 bins spaced logarithmically from the min to the max of your data
bins = np.logspace(np.log10(pdf["total_motion"].min()),
np.log10(pdf["total_motion"].max()),
100)
# plt.figure(figsize=(10, 6))
plt.hist(pdf["total_motion"], bins=bins)
plt.xscale('log')
plt.xlabel("Total Motion (Log Scale)")
plt.ylabel("Frequency")
plt.show()
We can gain the following insights from the plots:
The linear plot (figure 1) exhibits a massive “right-skew,” compressing 99% of the data into the first bin. This causes the model to over-weight the few extreme outliers (high motion stars) while treating the vast majority of the dataset as having near-zero variance.
The log transformation (Figure 2) reveals that the underlying data distribution is actually bimodal (two distinct peaks). The linear plot mathematically hid this physical distinction between the background stars (first peak) and the nearby high-proper-motion stars (second peak).
Regression algorithms (especially Linear Regression) assume constant variance across the range of values. The raw proper motion spans several orders of magnitude (from 0.01 to 10,000+); applying a log transform stabilizes the variance, making the error metrics (RMSE) meaningful across the entire dataset rather than just for the fastest stars.
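A minimal illustration of this stabilizing effect, using hypothetical values spanning the 0.01 to 10,000+ range noted above: six orders of magnitude collapse into a compact 0–4 scale.

```python
import math

# Hypothetical proper motions spanning six orders of magnitude (mas/yr).
motions = [0.01, 1.0, 100.0, 10_000.0]

# log10(x + 1) compresses the range while keeping small values near zero.
logged = [math.log10(m + 1) for m in motions]
print([round(v, 3) for v in logged])  # → [0.004, 0.301, 2.004, 4.0]
```

The `+ 1` offset (also used in the preprocessing code below) keeps the transform defined at zero motion.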
Code
Found 628 extreme outliers.
+-------------------+------------------+------------------+------------------+--------------+------------------+-------------------+---------------+---------+------------+------------------+
| source_id| ra| dec| parallax|parallax_error| pmra| pmdec|phot_g_mean_mag| bp_rp|teff_gspphot| total_motion|
+-------------------+------------------+------------------+------------------+--------------+------------------+-------------------+---------------+---------+------------+------------------+
|4472832130942575872|269.44850252543836| 4.739420051112412| 546.975939730948| 0.040116355|-801.5509783684709| 10362.394206546573| 8.1939745|2.8336968| 3099.6335|10393.348722273944|
|4810594479418041856| 77.9599373502188| -45.0438126993602|254.19859326384577| 0.016842743| 6491.223339061598| -5708.614150045243| 8.063552|2.0266457| 3451.8704| 8644.319287929779|
|4034171629042489088|178.26735320817272| 37.69282694689086|109.02963997046682| 0.019686269| 4002.654640989075|-5817.8001940492695| 6.1985016|1.0016494| 5043.2183| 7061.730897797727|
|6553614253923452800| 346.5039166796005| -35.8471642082214| 304.1353692001036| 0.01999573| 6765.995136250774| 1330.2852747179845| 6.5220323|2.0982852| 3376.0845|6895.5310959998305|
|2306965202564744064|1.3832841523481234|-37.36774402806293|230.09703402875448| 0.036182754| 5633.438087895326|-2334.7212726520424| 7.6824937| 2.186049| 3355.3533| 6098.077411047167|
|1872046609345556480| 316.7484792940004| 38.76386244649797|285.99494829578117| 0.05989728|4164.2086922846665| 3249.613883848584| 4.766713|1.4625897| 4353.7437| 5282.104166617755|
|3098328182579892096|122.99468256110134| 8.750401522062495|147.72184850183513| 0.094956644| 1069.811738087307| -5094.220103378359| 11.397289| 2.992114| NULL| 5205.341066310027|
|1872046574983497216| 316.753662752556| 38.75607277205679| 286.0053518616485| 0.028940246| 4105.976428209489| 3155.9416398273515| 5.4506445|1.7153406| 3889.6328| 5178.707373757288|
| 35227046884571776| 43.26964247679057| 16.86437381897744|260.98844068047276| 0.09342672|3429.0828268077694| -3805.54112273733| 12.263103|4.5285025| NULL| 5122.572817437822|
| 762815470562110464|165.83095967577933|35.948653032660104|392.75294543876464| 0.03206665|-580.0570872139048| -4776.588719443488| 6.551172|2.2156086| 3511.045| 4811.680165923527|
+-------------------+------------------+------------------+------------------+--------------+------------------+-------------------+---------------+---------+------------+------------------+
only showing top 10 rows
This confirms we are looking at real, high-quality data:
- Barnard's Star: The first outlier in the table (Source ID 4472832130942575872, total_motion ~10,393) is almost certainly Barnard's Star, which has the highest proper motion of any known star.
- The 100 pc limit: The minimum parallax is ~10.0 mas, which corresponds exactly to a distance of 100 parsecs (\(d = 1000/\pi\)), matching the filename gaia_100pc.
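The 100 pc selection limit can be verified directly from the parallax–distance relation used above:

```python
# Distance in parsecs from parallax in milliarcseconds: d = 1000 / parallax.
def distance_pc(parallax_mas: float) -> float:
    return 1000.0 / parallax_mas

print(distance_pc(10.0))  # → 100.0, the survey's selection limit
```

Larger parallaxes correspond to nearer stars, so the dataset's maximum parallax (~768 mas) marks its closest object.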
Preprocessing
To reduce skewness and improve model stability, the following logarithmic transformations were applied:
Code
import os
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression, RandomForestRegressor
from pyspark.ml.regression import LinearRegressionModel, RandomForestRegressionModel
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import log10
# Add log columns to handle the huge range of motion
df_ml = df.withColumn("log_motion", log10(col("total_motion") + 1))
df_ml = df_ml.withColumn("log_parallax", log10(col("parallax") + 1))
df_ml = df_ml.na.drop(subset=["bp_rp", "phot_g_mean_mag", "total_motion","teff_gspphot"])
# Prepare Features (Using the bimodal motion and magnitude)
assembler = VectorAssembler(
inputCols=["log_motion", "phot_g_mean_mag","bp_rp"],
outputCol="features",
handleInvalid="skip"
)
data = assembler.transform(df_ml)
# Split into Training (80%) and Testing (20%)
train_df, test_df = data.randomSplit([0.8, 0.2], seed=42)
lr_path = "models/gaia_linear_regression"
rf_path = "models/gaia_random_forest"
if os.path.exists(lr_path):
    # Load Model A: Linear Regression
    lr_model = LinearRegressionModel.load(lr_path)
else:
    # Train Model A: Linear Regression
    lr = LinearRegression(featuresCol="features", labelCol="log_parallax")
    lr_model = lr.fit(train_df)
    lr_model.write().overwrite().save(lr_path)
    print(f"Model saved to {lr_path}")

if os.path.exists(rf_path):
    # Load Model B: Random Forest
    rf_model = RandomForestRegressionModel.load(rf_path)
else:
    # Train Model B: Random Forest
    rf = RandomForestRegressor(
        featuresCol="features",
        labelCol="log_parallax",
        numTrees=100,
        maxDepth=12,
        # minInstancesPerNode=20,
        # featureSubsetStrategy="sqrt",
        seed=42
    )
    rf_model = rf.fit(train_df)
    rf_model.write().overwrite().save(rf_path)
    print(f"Model saved to {rf_path}")
# 5. Get Predictions
lr_pred = lr_model.transform(test_df)
rf_pred = rf_model.transform(test_df)
# 6. Evaluate
eval_r2 = RegressionEvaluator(labelCol="log_parallax", predictionCol="prediction", metricName="r2")
eval_rmse = RegressionEvaluator(labelCol="log_parallax", predictionCol="prediction", metricName="rmse")
print("Linear Regression R2:", eval_r2.evaluate(lr_pred))
print("Random Forest R2: ", eval_r2.evaluate(rf_pred))
print("-" * 30)
print("Linear Regression RMSE:", eval_rmse.evaluate(lr_pred))
print("Random Forest RMSE:   ", eval_rmse.evaluate(rf_pred))
Linear Regression R2: 0.6381499150350736
Random Forest R2: 0.7552582025878004
------------------------------
Linear Regression RMSE: 0.0882892047246452
Random Forest RMSE: 0.07261015068300783
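Because the label is in log10 space, an RMSE can be read as a typical multiplicative spread on (parallax + 1). A small standalone sketch, using the RMSE values printed above (the helper name is ours):

```python
# Translate an RMSE measured on log10(parallax + 1) into a
# multiplicative error factor on (parallax + 1). Illustrative only.
def rmse_to_factor(rmse_log10: float) -> float:
    return 10.0 ** rmse_log10

print(rmse_to_factor(0.0726))  # Random Forest: ~1.18x typical spread
print(rmse_to_factor(0.0883))  # Linear Regression: ~1.23x
```

This confirms the Random Forest's advantage in physical terms: its typical prediction error is a noticeably tighter multiplicative band than the linear model's.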
Code
import matplotlib.pyplot as plt
import seaborn as sns
# Convert Spark DataFrame to Pandas for plotting
pdf = rf_pred.select("log_parallax", "prediction", "bp_rp").sample(fraction=0.05, seed=42).toPandas()
# Scatter plot: predicted vs true log_parallax
plt.figure(figsize=(8, 6))
scatter = plt.scatter(
    pdf["log_parallax"],
    pdf["prediction"],
    c=pdf["bp_rp"],  # colour points by BP-RP colour index
    cmap="viridis",
    s=10,
    alpha=0.6
)
plt.plot([pdf["log_parallax"].min(), pdf["log_parallax"].max()],
         [pdf["log_parallax"].min(), pdf["log_parallax"].max()],
         color="red", linestyle="--", label="y = x")
plt.xlabel("True log(Parallax)")
plt.ylabel("Predicted log(Parallax)")
plt.colorbar(scatter, label="bp_rp")
plt.legend()
plt.show()
pdf_lr = lr_pred.select("log_parallax", "prediction", "bp_rp").sample(fraction=0.05, seed=42).toPandas()
plt.figure(figsize=(8,6))
sns.scatterplot(
    x=pdf_lr["log_parallax"], y=pdf_lr["prediction"],
    hue=pdf_lr["bp_rp"], palette="coolwarm", alpha=0.6, s=10
)
plt.plot([pdf_lr["log_parallax"].min(), pdf_lr["log_parallax"].max()],
         [pdf_lr["log_parallax"].min(), pdf_lr["log_parallax"].max()],
         color="black", linestyle="--", label="y = x")
plt.xlabel("True log(Parallax)")
plt.ylabel("Predicted log(Parallax)")
plt.legend()
plt.show()
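Both plots compare predictions in log space. To interpret a single prediction physically, the preprocessing transform can be inverted; a minimal standalone sketch (the function name is ours, not part of the pipeline):

```python
# The label is log10(parallax + 1), so parallax = 10**pred - 1 (in mas),
# and distance then follows from d = 1000 / parallax. Illustrative only.
def predicted_distance_pc(log_parallax_pred: float) -> float:
    parallax_mas = 10.0 ** log_parallax_pred - 1.0
    return 1000.0 / parallax_mas

# A predicted log_parallax of 1.0 implies a parallax of 9 mas, i.e. ~111 pc
print(round(predicted_distance_pc(1.0), 1))  # 111.1
```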